* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file kernel_arg.h
* \brief Registry of operator arguments (name, offset, raw tensor) keyed by the raw tensor's magic id.
*/
#ifndef __ASCENDTENSOR_KERNEL_ARG_H__
#define __ASCENDTENSOR_KERNEL_ARG_H__
#include <cassert>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "interface/tensor/logical_tensor.h"
namespace npu::tile_fwk {
/**
 * Record stored for one operator argument.
 *
 * Kept as a plain aggregate: the member order (name, offset, tensor) is relied
 * upon by brace-initialisation in KernelArg::SetOpArgs.
 */
struct ArgInfo {
    /// Symbol name recorded for the argument.
    std::string name{};
    /// Offset recorded for the argument (copied from the logical tensor's `offset`).
    std::vector<int> offset{};
    /// Raw tensor backing the argument; empty until assigned.
    std::shared_ptr<RawTensor> tensor{};
};
class KernelArg {
public:
bool IsOpArgs(std::shared_ptr<RawTensor> tensor)
{
return op_args_.find(tensor->rawmagic) != op_args_.end();
}
std::vector<std::string> GetOpArgName()
{
std::vector<std::string> res;
for (const auto& ele : op_args_) {
res.emplace_back(ele.second.name);
}
return res;
}
void SetOpArgs(std::vector<std::shared_ptr<LogicalTensor>> tensors)
{
for (auto tensor : tensors) {
int magic = tensor->tensor->rawmagic;
if (IsOpArgs(tensor->tensor)) {
auto& arg = op_args_[magic];
assert(tensor.GetStorage()->Symbol() == arg.name);
assert(tensor->offset == arg.offset);
} else {
std::cout << "Add args " << magic << " " << tensor.GetStorage()->Symbol() << std::endl;
op_args_[magic] = {tensor.GetStorage()->Symbol(), tensor->offset, tensor->tensor};
}
}
}
std::string GetOpArgName(std::shared_ptr<RawTensor> tensor)
{
for ([[maybe_unused]] const auto& ele : op_args_) {
if (op_args_.find(tensor->rawmagic) != op_args_.end()) {
return op_args_[tensor->rawmagic].name;
}
}
return "";
}
void clear() { op_args_.clear(); }
private:
std::map<int, ArgInfo> op_args_;
};
/**
 * Access the shared KernelArg registry.
 *
 * @return Reference to a function-local static KernelArg, constructed on the
 *         first call and returned on every subsequent call.
 */
inline KernelArg& GetKernelArg()
{
    static KernelArg instance;
    return instance;
}
}
#endif