#ifndef OPPLUGIN_UTILS_ATB_OPERATION_CREATE_H
#define OPPLUGIN_UTILS_ATB_OPERATION_CREATE_H
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <torch_npu/csrc/framework/OpCommand.h>
#include <torch_npu/csrc/core/npu/NPUGraphsUtils.h>
#include "op_plugin/third_party/atb/inc/atb_infer.h"
#include "OperationCacheCompute.h"
#include "Utils.h"
namespace atb {
/**
 * Per-thread cache of ATB operations, one cache per ATB parameter type.
 *
 * Operations are keyed by a 64-bit hash (computeHash) and owned by the cache:
 * every non-null operation in the map is released with DestroyOperation when the
 * cache is destroyed. getInstance() returns a thread_local instance, so each thread
 * has its own map and no locking is performed.
 */
template <typename ParamType>
class OpParamCache {
public:
    /// Returns the calling thread's cache instance (constructed on first use).
    static OpParamCache& getInstance();

    /**
     * Returns an operation for @p param, creating and caching it on a miss.
     * While the current stream is being captured, a new operation is created on
     * every call and is NOT stored in the cache.
     * @param param ATB parameter describing the operation.
     * @param name  Operation name, used in the error message if creation fails.
     * @return Non-null operation pointer; creation failure raises via TORCH_CHECK.
     */
    atb::Operation* getOperation(const ParamType& param, const std::string& name);

    /// Returns the operation cached under @p hash_id, or nullptr if there is none.
    atb::Operation* getOperation(uint64_t hash_id);

    /**
     * Stores @p op under @p hash_id. The cache destroys it in its destructor.
     * An existing entry for the same key is overwritten without being destroyed.
     */
    void saveOperation(uint64_t hash_id, atb::Operation* op);

private:
    OpParamCache();
    OpParamCache(const OpParamCache&) = delete;
    OpParamCache& operator=(const OpParamCache&) = delete;
    ~OpParamCache();

    // hash -> operation owned by this cache
    std::unordered_map<uint64_t, atb::Operation*> op_map_;
};

/**
 * Creates a new ATB operation from @p param. This function keeps no reference to it.
 * @param param ATB parameter describing the operation.
 * @param name  Operation name, used only in the error message.
 * @return Non-null operation pointer.
 * Raises (TORCH_CHECK) if creation reports a failure status or yields no operation.
 */
template <typename ParamType>
atb::Operation* CreateAtbOperation(const ParamType& param, const std::string& name)
{
    atb::Operation* op = nullptr;
    const auto status = atb::CreateOperation(param, &op);
    // A non-zero status is treated as failure (ATB's success code, atb::NO_ERROR, is 0).
    // Release anything a failed call produced so it is neither returned nor leaked.
    if (status != 0 && op != nullptr) {
        DestroyOperation(op);
        op = nullptr;
    }
    TORCH_CHECK(op != nullptr, name, " CreateOperation failed! status: ", status);
    return op;
}

template <typename ParamType>
OpParamCache<ParamType>& OpParamCache<ParamType>::getInstance()
{
    // One instance per thread, so op_map_ is never shared between threads.
    thread_local OpParamCache instance;
    return instance;
}

template <typename ParamType>
atb::Operation* OpParamCache<ParamType>::getOperation(const ParamType& param, const std::string& name)
{
    const auto is_capturing = static_cast<int>(c10_npu::currentStreamCaptureStatusMayInitCtx());
    if (is_capturing) {
        // During stream capture a fresh operation is created and not recorded in op_map_.
        // NOTE(review): this cache never destroys it — confirm the capture path releases it.
        return CreateAtbOperation(param, name);
    }

    // Entries are matched by hash alone: distinct params with equal hashes share one op.
    const uint64_t hash_value = computeHash(param);
    const auto it = op_map_.find(hash_value);
    // A null slot (left by saveOperation(..., nullptr)) counts as a miss so callers
    // always receive a usable operation.
    if (it != op_map_.end() && it->second != nullptr) {
        return it->second;
    }
    atb::Operation* op = CreateAtbOperation(param, name);
    op_map_[hash_value] = op;
    return op;
}

template <typename ParamType>
atb::Operation* OpParamCache<ParamType>::getOperation(uint64_t hash_id)
{
    const auto it = op_map_.find(hash_id);
    return it != op_map_.end() ? it->second : nullptr;
}

template <typename ParamType>
void OpParamCache<ParamType>::saveOperation(uint64_t hash_id, atb::Operation* op)
{
    // NOTE(review): if a different op is already stored under hash_id, it is overwritten
    // and leaked. It is deliberately not destroyed here because callers may still hold
    // and use that pointer.
    op_map_[hash_id] = op;
}

template <typename ParamType>
OpParamCache<ParamType>::OpParamCache()
{
    // Touch the ATB context-manager singleton while this cache is being constructed.
    // Presumably this fixes initialization (and therefore destruction) order so that the
    // context outlives the cached operations. Confirm against ContextManager.
    atb::utils::ContextManager::GetInstance();
}

template <typename ParamType>
OpParamCache<ParamType>::~OpParamCache()
{
    // Release every operation the cache owns, skipping empty (null) slots.
    for (auto& entry : op_map_) {
        if (entry.second != nullptr) {
            DestroyOperation(entry.second);
        }
    }
}
} // namespace atb
#endif