/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
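/*!
* \file codegen_cube.cpp
* \brief Code generation of cube (matmul) tile-op call statements and operand parameter strings for CodeGenOpNPU.
*/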
#include "codegen_op_npu.h"
#include "codegen/utils/codegen_utils.h"
#include "codegen/symbol_mgr/codegen_symbol.h"
#include "securec.h"
namespace npu::tile_fwk {
std::string CodeGenOpNPU::PrintMatmulTileTensor(
bool isAcc, std::unordered_map<OperandType, std::string>& tensorWithMemType) const
{
std::ostringstream oss;
bool hasBias = tensorWithMemType.count(OperandType::BUF_BT);
int64_t transModeNum = 0;
GetOpAttr(OpAttributeKey::transMode, transModeNum);
TransMode transMode = static_cast<TransMode>(transModeNum);
std::string transModeStr = "TransMode::CAST_NONE";
if (transMode == TransMode::CAST_RINT) {
transModeStr = "TransMode::CAST_RINT";
} else if (transMode == TransMode::CAST_ROUND) {
transModeStr = "TransMode::CAST_ROUND";
}
std::vector<std::string> paramList = {
tensorWithMemType[OperandType::BUF_L0C], tensorWithMemType[OperandType::BUF_L0A],
tensorWithMemType[OperandType::BUF_L0B]};
oss << tileOpName;
if (hasBias) {
paramList.emplace_back(tensorWithMemType[OperandType::BUF_BT]);
oss << WrapParamByAngleBrackets({transModeStr});
oss << WrapParamByParentheses(paramList) << ";\n";
return oss.str();
}
oss << WrapParamByAngleBrackets({std::to_string(isAcc), transModeStr});
oss << WrapParamByParentheses(paramList) << ";\n";
return oss.str();
}
std::string CodeGenOpNPU::PrintMatmulTileTensor(bool isAcc) const
{
std::unordered_map<OperandType, std::string> tensorWithMemType;
for (int i = 0; i < operandCnt; i++) {
tensorWithMemType.emplace(operandType[i], QueryTileTensorNameByIdx(i));
}
bool hasBias = tensorWithMemType.count(OperandType::BUF_BT);
bool isMXMad = tensorWithMemType.count(OperandType::BUF_L0AMX) || tensorWithMemType.count(OperandType::BUF_L0BMX);
if (!isMXMad) {
return PrintMatmulTileTensor(isAcc, tensorWithMemType);
}
std::ostringstream oss;
std::vector<std::string> mxParamList = {
tensorWithMemType[OperandType::BUF_L0C], tensorWithMemType[OperandType::BUF_L0A],
tensorWithMemType[OperandType::BUF_L0AMX], tensorWithMemType[OperandType::BUF_L0B],
tensorWithMemType[OperandType::BUF_L0BMX]};
oss << "MatmulMX";
if (hasBias) {
mxParamList.emplace_back(tensorWithMemType[OperandType::BUF_BT]);
oss << WrapParamByParentheses(mxParamList) << ";\n";
return oss.str();
}
oss << WrapParamByAngleBrackets({std::to_string(isAcc)});
oss << WrapParamByParentheses(mxParamList) << ";\n";
return oss.str();
}
/**
 * \brief Generate the cube (matmul) tile-op call statement for this op.
 *
 * When `isSupportLayout` is set, generation is delegated to PrintMatmulTileTensor with the accumulate
 * flag `!zeroC`. Otherwise a pointer-based call is emitted:
 *
 *   - dynamic functions:
 *       tileOpName<C, A, B, off0, off1, hasBias>((addr C*)c, (addr A*)a, (addr B*)b,
 *                                                m, k, n, zeroC, 0, l0cValid0, l0cValid1);
 *     where m/n are the two dynamic valid-shape entries of operand 0, k is entry 1 of operand 1, and
 *     hasBias comes from the optional "has_bias" op attribute (false when absent).
 *   - static functions:
 *       tileOpName<C, A, B, off0, off1, m, n>((addr C*)c, (addr A*)a, (addr B*)b, m, k, n, zeroC, 0);
 *     where m/n are shape[0][0]/shape[0][1] and k is shape[1][1].
 *
 * Operand 0 is treated as C, operand 1 as A and operand 2 as B.
 *
 * \param zeroC Emitted as the literal "true"/"false" argument of the generated call.
 * \return The generated statement, terminated by ";\n".
 */
std::string CodeGenOpNPU::GenCubeOp(bool zeroC) const
{
    if (isSupportLayout) {
        return PrintMatmulTileTensor(!zeroC);
    }

    // Always emitted as 0 in the generated call.
    const unsigned uf = 0;
    const std::string aVar = sm->QueryVarNameByTensorMagic(operandWithMagic[ID1]);
    const std::string bVar = sm->QueryVarNameByTensorMagic(operandWithMagic[ID2]);
    const std::string cVar = sm->QueryVarNameByTensorMagic(operandWithMagic[ID0]);
    const std::string aDtypeStr = DataType2CCEStr(operandDtype[ID1]);
    const std::string bDtypeStr = DataType2CCEStr(operandDtype[ID2]);
    const std::string cDtypeStr = DataType2CCEStr(operandDtype[ID0]);
    const char* zeroCStr = zeroC ? "true" : "false";

    // The three casted pointer arguments are identical for the dynamic and the static form.
    std::ostringstream ptrArgs;
    ptrArgs << "(" << GetAddrTypeByOperandType(operandType[ID0]) << " " << cDtypeStr << "*)" << cVar << ", "
            << "(" << GetAddrTypeByOperandType(operandType[ID1]) << " " << aDtypeStr << "*)" << aVar << ", "
            << "(" << GetAddrTypeByOperandType(operandType[ID2]) << " " << bDtypeStr << "*)" << bVar;

    // Shared head of the template-argument list: dtypes followed by the two offsets of operand 0.
    std::ostringstream oss;
    oss << tileOpName << "<" << cDtypeStr << ", " << aDtypeStr << ", " << bDtypeStr << ", " << offset[ID0][ID0]
        << ", " << offset[ID0][ID1];

    if (isDynamicFunction) {
        auto l0cShapeDyn = dynamicValidShape[ID0];
        auto l0aShapeDyn = dynamicValidShape[ID1];

        // The bias flag is optional; build the attribute key once for both the lookup and the read.
        bool hasBias = false;
        const auto hasBiasKey = OP_ATTR_PREFIX + "has_bias";
        if (opAttrs.count(hasBiasKey)) {
            hasBias = AnyCast<bool>(opAttrs.at(hasBiasKey));
        }

        // m and n are emitted twice: once as the m/k/n triple and once as the trailing L0C valid shape.
        oss << ", " << std::to_string(hasBias) << ">"
            << "(" << ptrArgs.str() << ", "
            << SymbolicExpressionTable::BuildExpression(l0cShapeDyn[ID0]) << ", "
            << SymbolicExpressionTable::BuildExpression(l0aShapeDyn[ID1]) << ", "
            << SymbolicExpressionTable::BuildExpression(l0cShapeDyn[ID1]) << ", " << zeroCStr << ", " << uf
            << ", " << SymbolicExpressionTable::BuildExpression(l0cShapeDyn[ID0]) << ", "
            << SymbolicExpressionTable::BuildExpression(l0cShapeDyn[ID1]) << ");\n";
    } else {
        const int64_t m = shape[ID0][ID0];
        const int64_t k = shape[ID1][ID1];
        const int64_t n = shape[ID0][ID1];
        oss << ", " << m << ", " << n << ">"
            << "(" << ptrArgs.str() << ", " << m << ", " << k << ", " << n << ", " << zeroCStr << ", " << uf
            << ");\n";
    }
    return oss.str();
}
/**
 * \brief Generate the cube op statement with the zeroC flag set (GenCubeOp(true)).
 * \return The statement produced by GenCubeOp.
 */
std::string CodeGenOpNPU::GenCubeOpMatmul() const
{
    const bool zeroC = true;
    return GenCubeOp(zeroC);
}
/**
 * \brief Generate the cube op statement with the zeroC flag cleared (GenCubeOp(false)).
 * \return The statement produced by GenCubeOp.
 */
std::string CodeGenOpNPU::GenCubeOpMatmulAcc() const
{
    const bool zeroC = false;
    return GenCubeOp(zeroC);
}
/**
 * \brief Build the comma-separated list of casted pointer arguments for this op's operands.
 *
 * Every operand slot (0 .. MAX_OPERANDS-1) that is neither NULL_OPERAND nor listed in `skipOperands`
 * contributes one entry of the form "(<addr-space> <dtype>*)<var>":
 *   - BUF_DDR operands use the variable produced by GenGmParamVar;
 *   - all other operands use the symbol-manager variable name, to which AppendLocalBufferVarOffset is
 *     applied unless the op is one of the L1->L0 moves (OP_L1_TO_L0A/B, OP_L1_TO_L0_BT/AT).
 *
 * \param skipOperands Operand indices to leave out of the list.
 * \return The entries joined with ", " (via JoinString).
 */
std::string CodeGenOpNPU::GenParamsStr(const std::unordered_set<int32_t>& skipOperands) const
{
    // The opcode does not change while iterating, so evaluate the L1->L0 test once.
    const bool isL1ToL0Move = opCode == Opcode::OP_L1_TO_L0A || opCode == Opcode::OP_L1_TO_L0B ||
                              opCode == Opcode::OP_L1_TO_L0_BT || opCode == Opcode::OP_L1_TO_L0_AT;

    std::vector<std::string> params;
    for (int i = 0; i < MAX_OPERANDS; i++) {
        // Filter first so that no strings are built for operands that are dropped anyway.
        if (operand[i] == NULL_OPERAND || skipOperands.count(i) != 0) {
            continue;
        }

        const std::string dtypeStr = DataType2CCEStr(operandDtype[i]);
        const std::string prefix = GetAddrTypeByOperandType(operandType[i]);

        std::string var;
        if (operandType[i] == BUF_DDR) {
            var = GenGmParamVar(i);
        } else {
            var = sm->QueryVarNameByTensorMagic(operandWithMagic[i]);
            if (!isL1ToL0Move) {
                // Takes `var` by reference wrapper, so it may rewrite the name in place.
                AppendLocalBufferVarOffset({{static_cast<unsigned>(i), std::ref(var)}});
            }
            CODEGEN_LOGD("GenParamsStr var: %s", var.c_str());
        }

        std::ostringstream oss;
        oss << "(" << prefix << " " << dtypeStr << "*)" << var;
        params.emplace_back(oss.str());
    }
    return JoinString(params, ", ");
}
}