* Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved.
*/
#ifndef OMNI_RUNTIME_DICTIONARY_CONTAINER_H
#define OMNI_RUNTIME_DICTIONARY_CONTAINER_H
#include <memory>
#include "large_string_container.h"
#include "memory/aligned_buffer.h"
namespace omniruntime::vec::unsafe {
class UnsafeDictionaryContainer;
}
namespace omniruntime::vec {
using namespace mem;
/**
 * Dictionary-encoded storage: each logical row holds an int32 index; the row's value is
 * dictionary->GetValue(index + dictOffset).
 *
 * @tparam RAW_DATA_TYPE element type stored in the dictionary. When is_container_v<RAW_DATA_TYPE>
 *         is true the dictionary is a CONTAINER<RAW_DATA_TYPE>, otherwise an AlignedBuffer<RAW_DATA_TYPE>.
 * @tparam CONTAINER dictionary container template used for container-like element types
 *         (defaults to LargeStringContainer).
 */
template <typename RAW_DATA_TYPE, template <typename> typename CONTAINER = LargeStringContainer>
class DictionaryContainer {
    using DictType =
        std::conditional_t<is_container_v<RAW_DATA_TYPE>, CONTAINER<RAW_DATA_TYPE>, AlignedBuffer<RAW_DATA_TYPE>>;

public:
    /**
     * Builds a container by copying `valueSize` indices from a raw array.
     *
     * @param values     pointer to at least `valueSize` dictionary indices (may be null when valueSize == 0)
     * @param valueSize  number of rows; expected to be >= 0
     * @param dictionary shared dictionary the indices refer to
     * @param dictSize   number of dictionary entries visible to this container (searched by SetValue)
     * @param dictOffset offset added to every stored index when reading the dictionary
     */
    DictionaryContainer(const int32_t *values, int32_t valueSize, std::shared_ptr<DictType> dictionary,
        int32_t dictSize, int32_t dictOffset = 0)
        : valueSize(valueSize),
          dictionary(std::move(dictionary)),
          dictSize(dictSize),
          dictOffset(dictOffset),
          isSliced(false)
    {
        // `values` here is the parameter (it shadows the member). assign() both allocates and
        // sets size(); the former reserve() + operator[] writes were undefined behaviour because
        // reserve() leaves size() at 0.
        this->values.assign(values, values + valueSize);
    }

    /**
     * Builds a container that takes over an existing index vector.
     *
     * NOTE: the argument is moved from — the caller's vector is left in a valid but
     * unspecified (typically empty) state.
     * NOTE(review): `valueSize` is presumably expected to equal values.size(); it is not checked here.
     *
     * @param values     index vector to take ownership of
     * @param valueSize  number of rows
     * @param dictionary shared dictionary the indices refer to
     * @param dictSize   number of dictionary entries visible to this container
     * @param dictOffset offset added to every stored index when reading the dictionary
     */
    DictionaryContainer(std::vector<int32_t> &values, int32_t valueSize, std::shared_ptr<DictType> dictionary,
        int32_t dictSize, int32_t dictOffset = 0)
        : values(std::move(values)),
          valueSize(valueSize),
          dictionary(std::move(dictionary)),
          dictSize(dictSize),
          dictOffset(dictOffset),
          isSliced(true)
    {}

    ~DictionaryContainer() = default;

    /**
     * Creates a new container holding the rows at the given positions; the dictionary is
     * shared with this container, not copied.
     *
     * @param positions array of `length` row indices into this container (not bounds-checked)
     * @param length    number of positions to copy
     * @return a new container with `length` rows
     */
    std::shared_ptr<DictionaryContainer<RAW_DATA_TYPE, CONTAINER>> CopyPositions(const int32_t *positions,
        int32_t length)
    {
        std::vector<int32_t> newValues(length);
        for (int32_t i = 0; i < length; i++) {
            newValues[i] = values[positions[i]];
        }
        // newValues binds to the std::vector<int32_t>& constructor, which moves it into the new container.
        return std::make_shared<DictionaryContainer<RAW_DATA_TYPE, CONTAINER>>(newValues, length, dictionary, dictSize,
            dictOffset);
    }

    /**
     * Returns the dictionary value for row `index` (not bounds-checked).
     */
    ALWAYS_INLINE typename PARAM_TYPE<RAW_DATA_TYPE>::type GetValue(int32_t index)
    {
        return dictionary->GetValue(values[index] + dictOffset);
    }

    /**
     * Points row `index` at the first dictionary entry (within [dictOffset, dictOffset + dictSize))
     * that compares equal to `value`. The stored index is relative to dictOffset.
     * Performs a linear scan of the dictionary (O(dictSize)); `index` is not bounds-checked.
     *
     * @throws OmniException if no dictionary entry equals `value`
     */
    ALWAYS_INLINE void SetValue(int32_t index, const RAW_DATA_TYPE &value)
    {
        for (int32_t i = 0; i < dictSize; i++) {
            RAW_DATA_TYPE dicValue = dictionary->GetValue(i + dictOffset);
            if (dicValue == value) {
                values[index] = i;
                return;
            }
        }
        throw OmniException("OPERATOR_RUNTIME_ERROR", "setting to a value doesn't exist in the dictionary");
    }

    /**
     * Returns the memory footprint accounted to this container: the object itself, plus the
     * index array when the container was built from a raw pointer (isSliced == false).
     * The shared dictionary is never counted here.
     */
    ALWAYS_INLINE int64_t GetContainerCapacity()
    {
        int64_t containerCapacity = static_cast<int64_t>(sizeof(DictionaryContainer));
        if (!isSliced) {
            // Stay in signed 64-bit arithmetic rather than mixing int32_t with size_t.
            containerCapacity += static_cast<int64_t>(valueSize) * static_cast<int64_t>(sizeof(int32_t));
        }
        return containerCapacity;
    }

private:
    friend class unsafe::UnsafeDictionaryContainer;
    std::vector<int32_t> values;          // per-row dictionary indices, relative to dictOffset
    int32_t valueSize;                    // number of rows
    std::shared_ptr<DictType> dictionary; // shared dictionary storage
    int32_t dictSize;                     // number of dictionary entries searched by SetValue
    int32_t dictOffset;                   // added to each stored index when reading the dictionary
    bool isSliced;                        // true when built from a moved-in vector; excludes indices from capacity
};
}
#endif