* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "ascir/Target/Asc/EmitAsc.h"
#include "ascir/Dialect/Asc/IR/Asc.h"
#include "ascir/Dialect/EmitAsc/IR/EmitAsc.h"
#include "mlir/IR/BuiltinAttributes.h"
using namespace mlir;
using namespace mlir::emitasc;
LogicalResult mlir::emitasc::printOperation(CodeEmitter& emitter, emitasc::CallOpaqueOp op)
{
    // Emits an opaque call `[result =] callee(arg0, arg1, ...)`, where the callee
    // is printed verbatim and every operand is referenced by its SSA-derived name.
    FAIL_OR(emitter.emitAssignPrefix(*op.getOperation()));
    auto& out = emitter.ostream();
    out << op.getCallee() << '(';
    llvm::interleave(
        op.getOperands(), out, [&](Value arg) { out << emitter.getOrCreateName(arg); }, ", ");
    out << ')';
    return success();
}
LogicalResult mlir::emitasc::printOperation(CodeEmitter& emitter, emitasc::CopyStructOp op)
{
    // Declares the destination struct and fills it with a byte-wise copy from the
    // source pointer. When the source memref carries an integer memory space that
    // maps to a known address space, the byte pointer is qualified with it.
    FAIL_OR(emitter.emitType(op.getLoc(), op.getType()));
    // Name creation order matters (base first, then result).
    auto srcName = emitter.getOrCreateName(op.getBase());
    auto dstName = emitter.getOrCreateName(op.getResult());
    auto& os = emitter.ostream();
    os << ' ' << dstName << ";\n"
       << "for (size_t i = 0; i < sizeof(" << dstName << "); i++) {\n";
    os.indent() << "auto byte = reinterpret_cast<";
    auto memSpace = dyn_cast_if_present<IntegerAttr>(op.getBase().getType().getMemorySpace());
    if (memSpace) {
        auto rawSpace = static_cast<uint8_t>(memSpace.getValue().getSExtValue());
        auto addrSpace = ascendc::symbolizeAddressSpace(rawSpace);
        if (addrSpace)
            emitter.emitAddressSpace(*addrSpace);
    }
    os << "uint8_t*>(" << srcName << ")[i];\n"
       << "reinterpret_cast<uint8_t*>(&" << dstName << ")[i] = byte;\n";
    os.unindent() << '}';
    return success();
}
LogicalResult mlir::emitasc::printOperation(CodeEmitter& emitter, emitasc::DereferenceOp op)
{
    // Binds a reference to the pointee of `base`: `T& result = *base`.
    FAIL_OR(emitter.emitType(op.getLoc(), op.getType()));
    auto& out = emitter.ostream();
    out << "& " << emitter.getOrCreateName(op.getResult());
    out << " = *" << emitter.getOrCreateName(op.getBase());
    return success();
}
LogicalResult mlir::emitasc::printOperation(CodeEmitter& emitter, emitasc::MemberOp op)
{
    // Reads a struct field: `[result =] base->field` when the base is a memref
    // (emitted as a pointer), `[result =] base.field` otherwise.
    FAIL_OR(emitter.emitAssignPrefix(*op.getOperation()));
    auto base = op.getBase();
    const char* accessor = isa<MemRefType>(base.getType()) ? "->" : ".";
    emitter.ostream() << emitter.getOrCreateName(base) << accessor << op.getField();
    return success();
}
LogicalResult mlir::emitasc::printOperation(CodeEmitter& emitter, emitasc::MemberPtrOp op)
{
    // Takes the address of a struct field and casts it to the result type:
    // `[result =] reinterpret_cast<T>(&base->field)`. Without an explicit field
    // name, the field is addressed positionally as `<structFieldNamePrefix><index>`.
    auto& out = emitter.ostream();
    FAIL_OR(emitter.emitAssignPrefix(*op.getOperation()));
    out << "reinterpret_cast<";
    FAIL_OR(emitter.emitType(op.getLoc(), op.getType()));
    out << ">(&" << emitter.getOrCreateName(op.getBase()) << "->";
    auto fieldName = op.getFieldAttr();
    if (!fieldName) {
        out << emitter.structFieldNamePrefix << op.getIndex();
    } else {
        out << fieldName.str();
    }
    out << ')';
    return success();
}
LogicalResult mlir::emitasc::printOperation(CodeEmitter& emitter, emitasc::DeclarePyStructOp op)
{
    // Emits a C++ struct definition mirroring the op's PyStructType, wrapped in
    // `#pragma pack(push, 8)` / `#pragma pack(pop)`. Each field is declared as
    // `<type> <name>;` in the order given by the type's parallel type/name lists.
    //
    // All validation happens before anything is written so that a failure never
    // leaves a partial struct or an unmatched `#pragma pack(push, ...)` behind.
    auto pType = dyn_cast_if_present<emitasc::PyStructType>(op.getPystruct());
    if (!pType) {
        return op.emitOpError("expected a PyStructType to declare");
    }
    auto fieldTypes = pType.getTypesAttr();
    auto fieldNames = pType.getNamesAttr();
    // zip_equal below only asserts on a length mismatch; report it as a diagnostic instead.
    if (fieldTypes.size() != fieldNames.size()) {
        return op.emitOpError("PyStructType has ")
               << fieldTypes.size() << " field types but " << fieldNames.size() << " field names";
    }

    auto& os = emitter.ostream();
    os << "#pragma pack(push, 8)\n";
    os << "struct " << pType.getNameAttr().getValue() << " {\n";
    os.indent();
    for (auto [typeAttr, nameAttr] : llvm::zip_equal(fieldTypes, fieldNames)) {
        FAIL_OR(emitter.emitType(op.getLoc(), cast<TypeAttr>(typeAttr).getValue()));
        os << " " << cast<StringAttr>(nameAttr).getValue() << ";\n";
    }
    os.unindent() << "};\n";
    os << "#pragma pack(pop)\n";
    return success();
}
LogicalResult mlir::emitasc::printOperation(CodeEmitter& emitter, emitasc::MemberRefOp op)
{
    // Binds a typed reference to a struct field:
    // `T& result = reinterpret_cast<T&>(base->field)`. Without an explicit field
    // name, the field is addressed positionally as `<structFieldNamePrefix><index>`.
    auto& out = emitter.ostream();
    auto loc = op.getLoc();
    auto refType = op.getType();
    FAIL_OR(emitter.emitType(loc, refType));
    out << "& " << emitter.getOrCreateName(op.getResult()) << " = reinterpret_cast<";
    FAIL_OR(emitter.emitType(loc, refType));
    out << "&>(" << emitter.getOrCreateName(op.getBase()) << "->";
    auto fieldName = op.getFieldAttr();
    if (!fieldName) {
        out << emitter.structFieldNamePrefix << op.getIndex();
    } else {
        out << fieldName.str();
    }
    out << ')';
    return success();
}
LogicalResult mlir::emitasc::printOperation(CodeEmitter& emitter, emitasc::PtrOffsetOp op)
{
    // Emits pointer arithmetic `[result =] base + offset`. The offset is the
    // dynamic operand when present, otherwise the static offset attribute.
    FAIL_OR(emitter.emitAssignPrefix(*op.getOperation()));
    auto& out = emitter.ostream();
    out << emitter.getOrCreateName(op.getBase()) << " + ";
    auto dynOffset = op.getDynamicOffset();
    if (!dynOffset) {
        FAIL_OR(emitter.emitAttribute(op.getLoc(), op.getStaticOffsetAttr()));
        return success();
    }
    out << emitter.getOrCreateName(dynOffset);
    return success();
}
LogicalResult mlir::emitasc::printOperation(CodeEmitter& emitter, emitasc::ReinterpretCastOp op)
{
    // Emits `[result =] reinterpret_cast<T>(source)` with T being the op's result type.
    FAIL_OR(emitter.emitAssignPrefix(*op.getOperation()));
    emitter.ostream() << "reinterpret_cast<";
    FAIL_OR(emitter.emitType(op.getLoc(), op.getType()));
    emitter.ostream() << ">(" << emitter.getOrCreateName(op.getSource()) << ')';
    return success();
}
LogicalResult mlir::emitasc::printOperation(CodeEmitter& emitter, emitasc::SetMemberOp op)
{
    // Assigns a struct field: `base.field = value`, or `base->field = value` when
    // the base is a memref. Memref bases are emitted as pointers (the MemberOp
    // printer uses `->` for them too), so `.` would produce invalid C++ for them.
    auto& os = emitter.ostream();
    Value base = op.getBase();
    const char* accessor = isa<MemRefType>(base.getType()) ? "->" : ".";
    os << emitter.getOrCreateName(base) << accessor << op.getField() << " = "
       << emitter.getOrCreateName(op.getValue());
    return success();
}
LogicalResult mlir::emitasc::printOperation(CodeEmitter& emitter, emitasc::VariableOp op)
{
    // Declares a brace-initialized local `T name[d0][d1]...{init}`. T is the
    // result's element type and each shape dimension becomes one array extent.
    // The initializer is the static attribute when the op is static, otherwise
    // the name of the dynamic init value.
    auto& out = emitter.ostream();
    auto loc = op.getLoc();
    auto var = op.getResult();
    auto varType = var.getType();
    FAIL_OR(emitter.emitType(loc, varType.getElementType()));
    out << ' ' << emitter.getOrCreateName(var);
    for (auto extent : varType.getShape())
        out << '[' << extent << ']';
    out << '{';
    if (!op.isStatic()) {
        out << emitter.getOrCreateName(op.getDynamicInit());
    } else {
        FAIL_OR(emitter.emitAttribute(loc, op.getStaticInitAttr()));
    }
    out << '}';
    return success();
}
LogicalResult mlir::emitasc::printOperation(CodeEmitter& emitter, emitasc::VerbatimOp op)
{
    // Emits the op's code string verbatim, replacing each `$N` placeholder
    // (N = decimal operand index) with the name of operand N. Placeholders whose
    // index is out of range or unparsable are left untouched. With no operands,
    // the string is emitted without any scanning.
    auto& os = emitter.ostream();
    auto args = op.getArgs();
    auto code = op.getValue();
    if (args.empty()) {
        os << code;
        return success();
    }
    // <cctype> classification is undefined for negative `char` values, which occur
    // for non-ASCII (e.g. UTF-8) bytes in the verbatim text, so go via unsigned char.
    auto isDecimalDigit = [](char c) { return isdigit(static_cast<unsigned char>(c)) != 0; };

    std::string result;
    result.reserve(2 * code.size());
    const char* data = code.data();
    const size_t size = code.size();
    size_t rem = 0; // start of literal text not yet copied into `result`
    size_t i = 1;   // candidate index-start position, i.e. one past a possible '$'
    while (i < size) {
        if (code[i - 1] != '$') {
            ++i;
            continue;
        }
        size_t j = i;
        while (j < size && isDecimalDigit(code[j])) {
            ++j;
        }
        if (j > i) {
            size_t index = 0;
            auto fcResult = std::from_chars(data + i, data + j, index);
            if (fcResult.ec == std::errc() && index < args.size()) {
                // Flush literal text up to (not including) the '$', then the operand name.
                result.append(data + rem, data + i - 1);
                result += emitter.getOrCreateName(args[index]);
                rem = j;
                i = j;
                continue;
            }
        }
        ++i;
    }
    result.append(data + rem, data + size);
    os << result;
    return success();
}