/**
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef MINDIE_LLM_INFERENCE_REQUEST_ID_H
#define MINDIE_LLM_INFERENCE_REQUEST_ID_H
#include <cstdint>
#include <functional>
#include <string>
#include <utility>
namespace mindie_llm {
/// class InferRequestId
///
/// Identifier of an inference request. An id holds either a uint64_t index or a string label, and
/// Type() reports which of the two is active. The class provides equality operators, a comparator
/// (Compare) that is a strict weak ordering usable with ordered containers, and is paired with a
/// std::hash specialization for hash-based containers.
class InferRequestId {
public:
    /// Which field of the id is active.
    enum class DataType { UINT64, STRING };

    /// Constructs a UINT64 id with index 0 and an empty label.
    InferRequestId() = default;

    /// Constructs a STRING id.
    ///
    /// \param requestLabel The label to store; the index is set to 0.
    explicit InferRequestId(std::string requestLabel)
        : requestLabel_(std::move(requestLabel)), requestIndex_(0), idType_(DataType::STRING) {}

    /// Constructs a UINT64 id.
    ///
    /// \param requestIndex The index to store; the label is left empty.
    explicit InferRequestId(uint64_t requestIndex) : requestIndex_(requestIndex), idType_(DataType::UINT64) {}

    // Copy/move construction and copy/move assignment are compiler-generated (Rule of Zero).
    // Hand-declaring any of them would suppress the implicit move operations and force a copy of
    // requestLabel_ on every move.

    /// Re-assigns this object as a UINT64 id: clears the label and stores rhs as the index.
    ///
    /// \param rhs The new index.
    InferRequestId &operator=(const uint64_t rhs) {
        requestLabel_.clear();
        requestIndex_ = rhs;
        idType_ = DataType::UINT64;
        return *this;
    }

    /// Re-assigns this object as a STRING id: stores rhs as the label and resets the index to 0.
    ///
    /// \param rhs The new label.
    InferRequestId &operator=(const std::string &rhs) {
        requestLabel_ = rhs;
        requestIndex_ = 0;
        idType_ = DataType::STRING;
        return *this;
    }

    /// Returns which field of the id is active.
    DataType Type() const { return idType_; }

    /// Returns the label (empty for UINT64 ids).
    const std::string &StringValue() const { return requestLabel_; }

    /// Returns the index (0 for STRING ids).
    uint64_t UnsignedIntValue() const { return requestIndex_; }

    /// Returns a printable form of the id: the decimal index for UINT64 ids, the label otherwise.
    /// Returned by non-const value so callers can move from the result.
    std::string GetRequestIdString() const {
        if (idType_ == DataType::UINT64) {
            return std::to_string(requestIndex_);
        }
        return requestLabel_;
    }

    /// Strict weak ordering over InferRequestId for ordered containers.
    ///
    /// Ids are ordered first by type (all UINT64 ids before all STRING ids), then by value:
    /// numerically for UINT64 ids, lexicographically for STRING ids. Two ids are equivalent under
    /// this ordering exactly when operator== considers them equal. Comparing actual values rather
    /// than hashes keeps distinct labels with colliding hashes from being treated as the same key.
    struct Compare {
        bool operator()(const InferRequestId &lhs, const InferRequestId &rhs) const {
            if (lhs.Type() != rhs.Type()) {
                return lhs.Type() < rhs.Type();
            }
            if (lhs.Type() == DataType::STRING) {
                return lhs.StringValue() < rhs.StringValue();
            }
            return lhs.UnsignedIntValue() < rhs.UnsignedIntValue();
        }
    };

    /// Two ids are equal when they have the same type and the same active value. The inactive
    /// field is not compared.
    friend bool operator==(const InferRequestId &lhs, const InferRequestId &rhs) {
        if (lhs.idType_ != rhs.idType_) {
            return false;
        }
        if (lhs.idType_ == DataType::STRING) {
            return lhs.requestLabel_ == rhs.requestLabel_;
        }
        return lhs.requestIndex_ == rhs.requestIndex_;
    }

    /// Negation of operator==.
    friend bool operator!=(const InferRequestId &lhs, const InferRequestId &rhs) { return !(lhs == rhs); }

private:
    std::string requestLabel_{};         ///< Label of the request; used only when idType_ is STRING.
    uint64_t requestIndex_{};            ///< Index of the request; used only when idType_ is UINT64.
    DataType idType_{DataType::UINT64};  ///< Which of the two fields is active.
};
} // namespace mindie_llm
namespace std {
/// std::hash specialization for mindie_llm::InferRequestId, enabling its use as a key in
/// unordered containers.
///
/// Only the field that is active for the id's type is hashed: the string label for STRING ids,
/// the integer index for UINT64 ids. Ids of different types may produce the same hash value;
/// operator== still distinguishes them because it compares the types first.
///
/// \param reqId The InferRequestId object to hash.
/// \return The hash value of the active field of reqId.
template <>
struct hash<mindie_llm::InferRequestId> {
    size_t operator()(const mindie_llm::InferRequestId &reqId) const {
        using IdKind = mindie_llm::InferRequestId::DataType;
        if (reqId.Type() != IdKind::STRING) {
            return hash<uint64_t>{}(reqId.UnsignedIntValue());
        }
        return hash<string>{}(reqId.StringValue());
    }
};
} // namespace std
#endif // MINDIE_LLM_INFERENCE_REQUEST_ID_H