#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/IR/PatternMatch.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
namespace {

/// Lowers `triton::PrintOp` to device-side printf calls emitted through
/// `TargetInfoBase`.
///
/// With no operands a single line "pid (<x>, <y>, <z>)<prefix>" is printed.
/// Otherwise, for every operand, each element unpacked for the current thread
/// is printed on its own line as
///   pid (<x>, <y>, <z>) idx (<i0>, <i1>, ...)<prefix>(operand <n>) <elem>
/// where "(operand <n>) " is emitted only when the op has more than one
/// operand, and "idx ()" is empty for scalar operands.
struct PrintOpConversion : public ConvertOpToLLVMPattern<triton::PrintOp> {
  explicit PrintOpConversion(LLVMTypeConverter &typeConverter,
                             const TargetInfoBase &targetInfo,
                             PatternBenefit benefit)
      : mlir::ConvertOpToLLVMPattern<triton::PrintOp>(typeConverter, benefit),
        targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();

    // Program ids along X, Y and Z; they lead every printed line.
    std::array<Value, 3> pid;
    auto module = op->getParentOfType<ModuleOp>();
    for (auto axis : {ProgramIDDim::X, ProgramIDDim::Y, ProgramIDDim::Z})
      pid[static_cast<int>(axis)] =
          targetInfo.programId(rewriter, loc, module, axis);

    if (op.getNumOperands() == 0) {
      std::string formatStr;
      llvm::raw_string_ostream os(formatStr);
      os << "pid (" << getFormatSubstr(pid[0]) << ", "
         << getFormatSubstr(pid[1]) << ", " << getFormatSubstr(pid[2]) << ")"
         << op.getPrefix();
      // The pids are formatted as unsigned (getFormatSubstr's default), so
      // supply one explicit `false` per printf argument instead of an empty
      // signedness list -- the same one-flag-per-operand shape printTensor
      // always passes to the target.
      SmallVector<bool> isSignedOperands(pid.size(), false);
      llPrintf(formatStr, {pid[0], pid[1], pid[2]}, isSignedOperands,
               rewriter);
      rewriter.eraseOp(op);
      return success();
    }

    // One signedness flag per operand.
    assert(op.getNumOperands() == op.getIsSigned().size());
    for (size_t i = 0; i < op.getNumOperands(); i++) {
      bool isSigned = op.getIsSigned()[i] > 0;
      auto elems = unpackLLElements(loc, adaptor.getOperands()[i], rewriter);
      SmallVector<int, 8> dimWidths;
      SmallVector<SmallVector<Value>> indices;
      if (auto rankedTy =
              dyn_cast<RankedTensorType>(op.getOperand(i).getType())) {
        indices = emitIndices(loc, rewriter, targetInfo, rankedTy.getEncoding(),
                              rankedTy, true);
        // Pad each index column to the width of that dimension's largest
        // index so the printed rows line up.
        for (int64_t dim : rankedTy.getShape())
          dimWidths.push_back(getIndexWidth(dim));
      } else {
        // Non-tensor operands carry exactly one value and an empty index.
        assert(elems.size() == 1);
        indices.push_back({});
      }
      if (!elems.empty()) {
        printTensor(op.getPrefix(), i, op.getNumOperands(), elems, pid, indices,
                    dimWidths, op.getHex(), rewriter, isSigned);
      }
    }
    rewriter.eraseOp(op);
    return success();
  }

  /// Returns the printf field width for indices along a dimension of size
  /// `dim`. This equals ceil(log10(dim)) for every dim > 0: the number of
  /// decimal digits in the largest index `dim - 1`, or 0 when dim == 1.
  /// Non-positive sizes also give 0. The value is computed with integer
  /// arithmetic, so it is exact (no floating-point log10 rounding) and needs
  /// no <cmath>.
  static int getIndexWidth(int64_t dim) {
    if (dim <= 1)
      return 0;
    int width = 0;
    for (int64_t maxIndex = dim - 1; maxIndex > 0; maxIndex /= 10)
      ++width;
    return width;
  }

  /// Emits one printf per element in `elems`, in the line format described on
  /// the struct.
  ///
  /// `operand` and `numOperands` control the "(operand <n>) " tag. The
  /// per-element `indices` must each hold `dimWidths.size()` values. `hex`
  /// and `isSigned` select how each element is formatted.
  void printTensor(StringRef prefixStr, size_t operand, size_t numOperands,
                   ArrayRef<Value> elems, std::array<Value, 3> pid,
                   ArrayRef<SmallVector<Value>> indices,
                   ArrayRef<int> dimWidths, bool hex,
                   ConversionPatternRewriter &rewriter, bool isSigned) const {
    assert(!elems.empty());
    assert(elems.size() == indices.size());
    assert(dimWidths.size() == indices.front().size());

    // Upper bound on operands handed to a single printf call. Index
    // components beyond this budget are elided. NOTE(review): presumably a
    // backend printf limit -- confirm.
    constexpr unsigned kMaxPrintfOperands = 32;

    Value formatStrValue;
    int formatStrByteCount = 0;
    for (size_t i = 0; i < elems.size(); i++) {
      std::string formatStr;
      llvm::raw_string_ostream os(formatStr);
      SmallVector<Value, kMaxPrintfOperands> printfOperands;

      os << "pid (";
      for (size_t j = 0; j < pid.size(); j++) {
        if (j != 0)
          os << ", ";
        os << getFormatSubstr(pid[j]);
        printfOperands.push_back(pid[j]);
      }
      os << ") ";

      // Index components that still fit in the operand budget once the pids
      // are counted, keeping two slots in reserve for what follows the index.
      size_t maxAllowedRank = kMaxPrintfOperands - printfOperands.size() - 2;

      os << "idx (";
      const auto &index = indices[i];
      for (size_t dim = 0; dim < index.size(); dim++) {
        if (dim != 0)
          os << ", ";
        if (dim == maxAllowedRank) {
          os << "... (truncated)";
          break;
        }
        os << getFormatSubstr(index[dim], /*hex=*/false,
                              /*width=*/dimWidths[dim]);
        printfOperands.push_back(index[dim]);
      }
      os << ")" << prefixStr;

      if (numOperands > 1)
        os << "(operand " << operand << ") ";

      Value elem = elems[i];
      os << getFormatSubstr(elem, hex, /*width=*/std::nullopt, isSigned);
      printfOperands.push_back(elem);

      // Within one operand every element produces the same format text
      // (same pids, index arity/widths and element type). The string is
      // therefore materialized once, on the first element, and its Value is
      // reused for the remaining elements instead of creating duplicates.
      SmallVector<bool> isSignedOperands(printfOperands.size(), isSigned);
      if (i == 0) {
        formatStrValue = llPrintf(formatStr, printfOperands, isSignedOperands,
                                  rewriter, &formatStrByteCount);
      } else {
        targetInfo.printf(rewriter, formatStrValue, formatStrByteCount,
                          printfOperands, isSignedOperands);
      }
    }
  }

  /// Returns the printf conversion specifier for `value`:
  /// - LLVM pointers: "%p".
  /// - `hex`: "0x%0<N>x" ("0x%0<N>llx" above 32 bits), where N is the bit
  ///   width / 4. `width` and `isSigned` are ignored in this case.
  /// - bf16/f16/f32/f64: "%<width>f".
  /// - Integers: "%<width>i" or "%<width>u" ("lli"/"llu" at 64 bits),
  ///   chosen by `isSigned`.
  /// Any other type trips an assertion; with asserts disabled it yields "".
  std::string getFormatSubstr(Value value, bool hex = false,
                              std::optional<int> width = std::nullopt,
                              bool isSigned = false) const {
    Type type = value.getType();
    if (isa<LLVM::LLVMPointerType>(type)) {
      return "%p";
    }
    if (hex) {
      std::string ret =
          "0x%0" + std::to_string(type.getIntOrFloatBitWidth() / 4);
      if (type.getIntOrFloatBitWidth() > 32) {
        ret += "ll";
      }
      ret += "x";
      return ret;
    }
    std::string prefix = "%";
    if (width.has_value()) {
      prefix += std::to_string(*width);
    }
    if (type.isBF16() || type.isF16() || type.isF32() || type.isF64()) {
      return prefix + "f";
    } else if (type.isInteger()) {
      if (type.getIntOrFloatBitWidth() == 64)
        return prefix + (isSigned ? "lli" : "llu");
      else
        return prefix + (isSigned ? "i" : "u");
    }
    assert(false && "not supported type");
    return "";
  }

  /// Appends "\n" and a NUL terminator to `msg`, adds the result to the
  /// module as a string with the "printfFormat_" name prefix, and emits a
  /// printf of `args` that uses it.
  ///
  /// Returns the string's Value. If `formatStrByteCount` is non-null, it
  /// receives the string's size in bytes (including the terminator), so the
  /// caller can reuse the same string in later `targetInfo.printf` calls.
  Value llPrintf(StringRef msg, ValueRange args, ArrayRef<bool> isSigned,
                 ConversionPatternRewriter &rewriter,
                 int *formatStrByteCount = nullptr) const {
    assert(!msg.empty() && "printf with empty string not supported");
    llvm::SmallString<64> msgNewline(msg);
    msgNewline.push_back('\n');
    msgNewline.push_back('\0');
    Value msgValue =
        LLVM::addStringToModule(UnknownLoc::get(rewriter.getContext()),
                                rewriter, "printfFormat_", msgNewline);
    targetInfo.printf(rewriter, msgValue, msgNewline.size_in_bytes(), args,
                      isSigned);
    if (formatStrByteCount)
      *formatStrByteCount = msgNewline.size_in_bytes();
    return msgValue;
  }

protected:
  // Stored by reference: the TargetInfoBase must outlive this pattern.
  const TargetInfoBase &targetInfo;
};

} // namespace
/// Registers the `triton::PrintOp` -> LLVM lowering (PrintOpConversion) in
/// `patterns` with the given `benefit`.
///
/// The registered pattern keeps a reference to `targetInfo`, so `targetInfo`
/// must outlive the pattern set.
void mlir::triton::populatePrintOpToLLVMPattern(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    const TargetInfoBase &targetInfo, PatternBenefit benefit) {
  patterns.add<PrintOpConversion>(typeConverter, targetInfo,
                                  benefit);
}