* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "ascir/Target/Asc/Basic/OtherOps.h"
using namespace mlir;
using namespace mlir::ascendc;
namespace {
struct AippMemberInfo {
const char* aippMemberName;
const std::vector<const char*> subMemberNames;
};
const AippMemberInfo* getAippMemberInfo(size_t index)
{
static const std::vector<AippMemberInfo> memberInfos = {
{"paddingParams", {"paddingMode", "paddingValueCh0", "paddingValueCh1", "paddingValueCh2", "paddingValueCh3"}},
{"swapParams", {"isSwapRB", "isSwapUV", "isSwapAX"}},
{"singleLineParams", {"isSingleLineCopy"}},
{"dtcParams",
{"dtcMeanCh0", "dtcMeanCh1", "dtcMeanCh2", "dtcMinCh0", "dtcMinCh1", "dtcMinCh2", "dtcVarCh0", "dtcVarCh1",
"dtcVarCh2", "dtcRoundMode"}},
{"cPaddingParams", {"cPaddingMode", "cPaddingValue"}},
{"cscParams",
{"isEnableCsc", "cscMatrixR0C0", "cscMatrixR0C1", "cscMatrixR0C2", "cscMatrixR1C0", "cscMatrixR1C1",
"cscMatrixR1C2", "cscMatrixR2C0", "cscMatrixR2C1", "cscMatrixR2C2", "cscBiasIn0", "cscBiasIn1", "cscBiasIn2",
"cscBiasOut0", "cscBiasOut1", "cscBiasOut2"}}};
if (index >= memberInfos.size()) {
return nullptr;
}
return &memberInfos[index];
}
LogicalResult printAippMemberAssignment(CodeEmitter& emitter, ascendc::ConstructOp op, size_t memberIndex)
{
auto& os = emitter.ostream();
const AippMemberInfo* memberInfoPtr = getAippMemberInfo(memberIndex);
if (!memberInfoPtr) {
return op.emitError("Internal Error: Index out of bounds when accessing AippMemberInfo for member index ")
<< memberIndex;
}
auto subStructOp = dyn_cast_or_null<ascendc::ConstructOp>(op->getOperand(memberIndex).getDefiningOp());
if (!subStructOp) {
return op.emitError("Internal Error: Expected operand of AippParams to be a ConstructOp.");
}
if (subStructOp->getNumOperands() != memberInfoPtr->subMemberNames.size()) {
return op.emitError("Internal Error: Mismatch in operand count for AIPP member ")
<< memberInfoPtr->aippMemberName;
}
for (size_t i = 0; i < subStructOp->getNumOperands(); ++i) {
os << emitter.getOrCreateName(op->getResult(0)) << "." << memberInfoPtr->aippMemberName << "."
<< memberInfoPtr->subMemberNames[i] << " = " << emitter.getOrCreateName(subStructOp->getOperand(i)) << ";\n";
}
return success();
}
LogicalResult printAippStructConstruction(CodeEmitter& emitter, ascendc::ConstructOp op)
{
auto& os = emitter.ostream();
mlir::Type resultType = op->getResult(0).getType();
return llvm::TypeSwitch<mlir::Type, LogicalResult>(resultType)
.Case([&](ascendc::AippParamsType) -> LogicalResult {
constexpr size_t paddingParamsOpIndex = 0;
constexpr size_t templateTypeMemberIndex = 1;
if (op->getNumOperands() <= paddingParamsOpIndex) {
return op.emitError("Internal Error: AippParams is missing operands.");
}
auto paddingConstructOp =
dyn_cast_or_null<ascendc::ConstructOp>(op->getOperand(paddingParamsOpIndex).getDefiningOp());
if (!paddingConstructOp || paddingConstructOp->getNumOperands() <= templateTypeMemberIndex) {
return op.emitError("Internal Error: Cannot deduce template type from paddingParams.");
}
mlir::Type templateType = paddingConstructOp->getOperand(templateTypeMemberIndex).getType();
os << "AscendC::AippParams<";
if (failed(emitter.emitType(op.getLoc(), templateType))) {
return failure();
}
os << "> " << emitter.getOrCreateName(op->getResult(0)) << ";\n";
for (size_t i = 0; i < op->getNumOperands(); ++i) {
if (failed(printAippMemberAssignment(emitter, op, i))) {
return failure();
}
}
return success();
})
.Case<
ascendc::AippPaddingParamsType, ascendc::AippSwapParamsType, ascendc::AippSingleLineParamsType,
ascendc::AippDataTypeConvParamsType, ascendc::AippChannelPaddingParamsType,
ascendc::AippColorSpaceConvParamsType>([&](auto) -> LogicalResult { return success(); })
.Default([](auto) -> LogicalResult { return failure(); });
}
}
LogicalResult mlir::ascendc::printOperation(CodeEmitter& emitter, ascendc::ConstructOp op)
{
if (succeeded(printAippStructConstruction(emitter, op))) {
return success();
}
auto& os = emitter.ostream();
if (op.getIsStatic()) {
os << "static ";
}
if (op.getIsConstexpr()) {
os << "constexpr ";
}
FAIL_OR(emitter.emitVariableDeclaration(op->getResult(0), false));
if (op.getNumOperands() == 0) {
return success();
}
os << '{';
auto emitOperand = [&os, &emitter](Value operand, Type type) {
if (operand.getType() == type) {
os << emitter.getOrCreateName(operand);
return;
}
if (isa<MemRefType>(type)) {
os << "reinterpret_cast<";
} else {
os << "static_cast<";
}
(void)emitter.emitType(operand.getLoc(), type);
os << ">(" << emitter.getOrCreateName(operand) << ")";
};
SmallVector<Type> types;
if (auto typesAttr = op.getTypesAttr()) {
if (typesAttr.size() != op->getNumOperands()) {
return emitError(op.getLoc(), "Expect the size of typesAttr equals to numbers of operands");
}
llvm::transform(typesAttr, std::back_inserter(types), [](Attribute a) { return cast<TypeAttr>(a).getValue(); });
} else {
types.append(op->getOperandTypes().begin(), op->getOperandTypes().end());
}
llvm::interleaveComma(llvm::zip_equal(op.getOperands(), types), os, [&emitOperand](auto pair) {
const auto& [operand, type] = pair;
emitOperand(operand, type);
});
os << '}';
return success();
}
LogicalResult mlir::ascendc::printOperation(CodeEmitter& emitter, ascendc::FftsCrossCoreSyncOp op)
{
auto& os = emitter.ostream();
os << "ffts_cross_core_sync(" << ascendc::stringifyEnum(op.getPipe()).upper() << ", "
<< emitter.getOrCreateName(op.getConfig()) << ")";
return success();
}
LogicalResult mlir::ascendc::printOperation(CodeEmitter& emitter, ascendc::GetMrgSortResultOp op)
{
auto& os = emitter.ostream();
Value mrgSortList1Value = op.getMrgSortList1();
Value mrgSortList2Value = op.getMrgSortList2();
Value mrgSortList3Value = op.getMrgSortList3();
Value mrgSortList4Value = op.getMrgSortList4();
os << "uint16_t " << emitter.getOrCreateName(mrgSortList1Value) << ";\n";
os << "uint16_t " << emitter.getOrCreateName(mrgSortList2Value) << ";\n";
os << "uint16_t " << emitter.getOrCreateName(mrgSortList3Value) << ";\n";
os << "uint16_t " << emitter.getOrCreateName(mrgSortList4Value) << ";\n";
os << ascNamespace << "::" << op.getAPIName();
os << "(" << emitter.getOrCreateName(mrgSortList1Value);
os << ", " << emitter.getOrCreateName(mrgSortList2Value);
os << ", " << emitter.getOrCreateName(mrgSortList3Value);
os << ", " << emitter.getOrCreateName(mrgSortList4Value) << ")";
return success();
}
LogicalResult mlir::ascendc::printOperation(CodeEmitter& emitter, ascendc::MrgSortOp op)
{
static int elementCountListCounter = 0;
auto uniqueId = std::to_string(elementCountListCounter++);
auto& os = emitter.ostream();
auto elementCountListName = (emitter.getOrCreateName(op.getDst()) + "_element_count_list_" + uniqueId).str();
auto sortedNumName = (emitter.getOrCreateName(op.getDst()) + "_sorted_num_" + uniqueId).str();
os << "uint16_t " << elementCountListName << "[] = {";
llvm::interleaveComma(op.getElementCountList(), os, [&](Value operand) { os << emitter.getOrCreateName(operand); });
os << "};\n";
os << "uint32_t " << sortedNumName << "[] = {";
llvm::interleaveComma(op.getSortedNum(), os, [&](Value operand) { os << emitter.getOrCreateName(operand); });
os << "};\n";
os << ascNamespace << "::" << op.getAPIName();
auto tensorType = cast<ascendc::LocalTensorType>(op.getDst().getType()).getElementType();
os << "<";
FAIL_OR(emitter.emitType(op.getLoc(), tensorType));
os << ", " << op.getIsExhaustedSuspension() << ">"
<< "(" << emitter.getOrCreateName(op.getDst()) << ", " << emitter.getOrCreateName(op.getSortList()) << ", "
<< elementCountListName << ", " << sortedNumName << ", " << emitter.getOrCreateName(op.getValidBit()) << ", "
<< emitter.getOrCreateName(op.getRepeatTime()) << ")";
return success();
}
LogicalResult mlir::ascendc::printOperation(CodeEmitter& emitter, ascendc::SortOp op)
{
auto& os = emitter.ostream();
os << ascNamespace << "::" << op.getAPIName();
auto tensorType = cast<ascendc::LocalTensorType>(op.getDst().getType()).getElementType();
os << "<";
FAIL_OR(emitter.emitType(op.getLoc(), tensorType));
os << ", " << op.getIsFullSort() << ">"
<< "(" << emitter.getOrCreateName(op.getDst()) << ", " << emitter.getOrCreateName(op.getConcat()) << ", "
<< emitter.getOrCreateName(op.getIndex()) << ", " << emitter.getOrCreateName(op.getTmp()) << ", "
<< emitter.getOrCreateName(op.getRepeatTime()) << ")";
return success();
}
LogicalResult mlir::ascendc::printOperation(CodeEmitter& emitter, ascendc::PopStackBufferOp op)
{
auto& os = emitter.ostream();
os << ascNamespace << "::" << op.getAPIName() << "<";
FAIL_OR(emitter.emitType(op.getLoc(), op.getTensor().getType().getElementType()));
os << ", ";
CodeEmitter::emitTPosition(os, op.getPos());
os << ">(" << emitter.getOrCreateName(op.getTensor()) << ")";
return success();
}
LogicalResult mlir::ascendc::printOperation(CodeEmitter& emitter, ascendc::SetFftsBaseAddrOp op)
{
auto& os = emitter.ostream();
os << "set_ffts_base_addr(*" << emitter.getOrCreateName(op.getOperand()) << ")";
return success();
}
LogicalResult mlir::printOperation(CodeEmitter& emitter, LLVM::UndefOp op)
{
return emitter.emitVariableDeclaration(op->getResult(0), false);
}
LogicalResult mlir::ascendc::printOperation(CodeEmitter& emitter, ascendc::ResetMaskOp op);
LogicalResult mlir::ascendc::printOperation(CodeEmitter& emitter, ascendc::FixpipeOp op)
{
auto& os = emitter.ostream();
FAIL_OR(printFixpipeTemplate(emitter, op));
os << "(" << emitter.getOrCreateName(op.getDst()) << ", " << emitter.getOrCreateName(op.getSrc()) << ", "
<< emitter.getOrCreateName(op.getIntriParams()) << ")";
return success();
}
LogicalResult mlir::ascendc::printOperation(CodeEmitter& emitter, ascendc::FixpipeWithWorkspaceOp op)
{
auto& os = emitter.ostream();
FAIL_OR(printFixpipeTemplate(emitter, op));
os << "(" << emitter.getOrCreateName(op.getDst()) << ", " << emitter.getOrCreateName(op.getSrc()) << ", "
<< emitter.getOrCreateName(op.getCbufWorkspace()) << ", " << emitter.getOrCreateName(op.getIntriParams()) << ")";
return success();
}
LogicalResult mlir::ascendc::printOperation(CodeEmitter& emitter, ascendc::GetStoreAtomicConfigOp op)
{
auto& os = emitter.ostream();
Value atomicTypeValue = op.getAtomicType();
Value atomicOpValue = op.getAtomicOp();
os << "uint16_t " << emitter.getOrCreateName(atomicTypeValue) << ";\n";
os << "uint16_t " << emitter.getOrCreateName(atomicOpValue) << ";\n";
os << ascNamespace << "::" << op.getAPIName();
os << "(" << emitter.getOrCreateName(atomicTypeValue);
os << ", " << emitter.getOrCreateName(atomicOpValue) << ")";
return success();
}