/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
 * \file op_infer_shape_impl.h
 * \brief Registry of per-opcode infer-shape functions and the macro used to register them.
 */
#ifndef DEVICE_INFER_SHAPE_H
#define DEVICE_INFER_SHAPE_H
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>
#include "opcode.h"
#include "operation.h"
namespace npu::tile_fwk {
// Signature of an infer-shape routine: it receives the operation being processed and an
// output vector to which it adds one shape (a vector of SymbolicScalar) per output operand
// of that operation; the caller indexes the result by output operand position.
using FuncType = std::function<void(Operation* op, std::vector<std::vector<SymbolicScalar>>&)>;
class InferShapeRegistry {
private:
InferShapeRegistry() = default;
~InferShapeRegistry() = default;
public:
static InferShapeRegistry& GetInstance()
{
static InferShapeRegistry instance;
return instance;
}
void RegisterInferShapeFunc(const Opcode opcode, FuncType func) { inferShapeFuncs_[opcode] = func; }
void CallInferShapeFunc(Operation* op)
{
const Opcode opcode = op->GetOpcode();
std::vector<std::vector<SymbolicScalar>> outValidShapes;
auto it = inferShapeFuncs_.find(opcode);
if (it != inferShapeFuncs_.end()) {
it->second(op, outValidShapes);
} else {
PASS_LOGW("Infer shape failed, opcode [%s] doesn't support infer shape.", op->GetOpcodeStr().c_str());
for (auto output : op->GetOOperands()) {
auto immShape = OpImmediate::Specified(output->GetShape());
if (output->GetDynValidShape().empty()) {
std::vector<SymbolicScalar> validShape;
for (auto immDim : immShape) {
validShape.push_back(immDim.GetSpecifiedValue());
}
outValidShapes.push_back(validShape);
} else {
outValidShapes.push_back(output->GetDynValidShape());
}
}
}
for (size_t i = 0; i < op->GetOOperands().size(); ++i) {
if (op->GetOOperands()[i]->GetDynValidShape().empty() || op-> GetOpcode() == Opcode::OP_COPY_OUT || op-> GetOpcode() == Opcode::OP_ASSEMBLE) {
op->GetOOperands()[i]->UpdateDynValidShape(outValidShapes[i]);
}
}
}
private:
std::unordered_map<Opcode, FuncType> inferShapeFuncs_;
};
// Registers FuncName as the infer-shape function for OpType.
//
// The expansion defines a class named <OpCoreStr>Register whose constructor performs the
// registration, followed by a static object of that class named <OpCoreStr>_register, so the
// registration runs when that object is initialized. OpCoreStr must therefore be unique within
// the scope where the macro is used. The expansion has no trailing semicolon; the use site
// supplies it.
//
// The registry is named with its fully qualified name so the macro also expands correctly
// outside namespace npu::tile_fwk.
#define REGISTER_INFER_SHAPE_FUNC(OpCoreStr, OpType, FuncName)                                           \
    class OpCoreStr##Register {                                                                          \
    public:                                                                                              \
        OpCoreStr##Register()                                                                            \
        {                                                                                                \
            ::npu::tile_fwk::InferShapeRegistry::GetInstance().RegisterInferShapeFunc(OpType, FuncName); \
        }                                                                                                \
    };                                                                                                   \
    static OpCoreStr##Register OpCoreStr##_register
}
#endif