* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include "ascir/Target/Asc/External/Arith.h"
using namespace mlir;
namespace {
constexpr uint32_t bitWidth32 = 32;
constexpr uint32_t bitWidth16 = 16;
}
LogicalResult mlir::printOperation(CodeEmitter& emitter, arith::ConstantOp constantOp)
{
Operation* operation = constantOp.getOperation();
Attribute value = constantOp.getValue();
return printConstantOp(emitter, operation, value);
}
LogicalResult mlir::printOperation(CodeEmitter& emitter, arith::MulUIExtendedOp op)
{
auto& os = emitter.ostream();
auto resultType = op->getResult(1).getType();
auto lhs = emitter.getOrCreateName(op.getLhs());
auto rhs = emitter.getOrCreateName(op.getRhs());
if (failed(emitter.emitVariableDeclaration(op->getResult(0), false))) {
return failure();
}
os << " = " << lhs << " * " << rhs << ";\n";
auto typeSize = resultType.getIntOrFloatBitWidth();
if (failed(emitter.emitVariableDeclaration(op->getResult(1), false))) {
return failure();
}
if (typeSize == bitWidth32) {
os << " = (static_cast<uint64_t>(" << lhs << ") * static_cast<uint64_t>(" << rhs << ")) >> 32";
} else if (typeSize == bitWidth16) {
os << " = (static_cast<uint32_t>(" << lhs << ") * static_cast<uint32_t>(" << rhs << ")) >> 16";
} else {
llvm_unreachable("not implemented");
}
return success();
}
LogicalResult mlir::printOperation(CodeEmitter& emitter, arith::CmpIOp op)
{
if (failed(emitter.emitAssignPrefix(*op.getOperation()))) {
return failure();
}
auto& os = emitter.ostream();
os << emitter.getOrCreateName(op.getLhs()) << " ";
switch (op.getPredicate()) {
case arith::CmpIPredicate::eq:
os << "==";
break;
case arith::CmpIPredicate::ne:
os << "!=";
break;
case arith::CmpIPredicate::sle:
case arith::CmpIPredicate::ule:
os << "<=";
break;
case arith::CmpIPredicate::slt:
case arith::CmpIPredicate::ult:
os << "<";
break;
case arith::CmpIPredicate::sge:
case arith::CmpIPredicate::uge:
os << ">=";
break;
case arith::CmpIPredicate::sgt:
case arith::CmpIPredicate::ugt:
os << ">";
break;
}
os << " " << emitter.getOrCreateName(op.getRhs());
return success();
}
LogicalResult mlir::printOperation(CodeEmitter& emitter, arith::CmpFOp op)
{
if (failed(emitter.emitAssignPrefix(*op.getOperation()))) {
return failure();
}
auto& os = emitter.ostream();
os << emitter.getOrCreateName(op.getLhs()) << " ";
switch (op.getPredicate()) {
case arith::CmpFPredicate::OEQ:
case arith::CmpFPredicate::UEQ:
os << "==";
break;
case arith::CmpFPredicate::ONE:
case arith::CmpFPredicate::UNE:
os << "!=";
break;
case arith::CmpFPredicate::OLE:
case arith::CmpFPredicate::ULE:
os << "<=";
break;
case arith::CmpFPredicate::OLT:
case arith::CmpFPredicate::ULT:
os << "<";
break;
case arith::CmpFPredicate::OGE:
case arith::CmpFPredicate::UGE:
os << ">=";
break;
case arith::CmpFPredicate::OGT:
case arith::CmpFPredicate::UGT:
os << ">";
break;
case arith::CmpFPredicate::AlwaysFalse:
case arith::CmpFPredicate::AlwaysTrue:
case arith::CmpFPredicate::ORD:
case arith::CmpFPredicate::UNO:
llvm_unreachable("unsupported predicate in arith.cmpf operation");
}
os << " " << emitter.getOrCreateName(op.getRhs());
return success();
}
LogicalResult mlir::printOperation(CodeEmitter& emitter, arith::BitcastOp op)
{
FAIL_OR(emitter.emitAssignPrefix(*op.getOperation()));
auto& os = emitter.ostream();
os << "*reinterpret_cast<";
FAIL_OR(emitter.emitType(op.getLoc(), op.getType()));
os << "*>(&" << emitter.getOrCreateName(op.getIn()) << ")";
return success();
}
LogicalResult mlir::printOperation(CodeEmitter& emitter, arith::SelectOp op)
{
if (failed(emitter.emitAssignPrefix(*op.getOperation()))) {
return failure();
}
auto& os = emitter.ostream();
os << emitter.getOrCreateName(op.getCondition()) << " ? " << emitter.getOrCreateName(op.getTrueValue()) << " : "
<< emitter.getOrCreateName(op.getFalseValue());
return success();
}
LogicalResult mlir::printOperation(CodeEmitter& emitter, arith::IndexCastOp op)
{
if (failed(emitter.emitAssignPrefix(*op.getOperation()))) {
return failure();
}
auto& os = emitter.ostream();
os << "static_cast<";
if (failed(emitter.emitType(op.getLoc(), op.getOut().getType()))) {
return failure();
}
os << ">(" << emitter.getOrCreateName(op.getIn()) << ")";
return success();
}