#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include <algorithm>
#include <numeric>
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
using namespace mlir;
using namespace mlir::triton::gpu;
/// Builds the type converter used when lowering Triton IR to TritonGPU IR.
///
/// Registration order matters: MLIR's TypeConverter tries the most recently
/// added conversion first, so the identity fallback is registered first.
///
/// Conversions:
///  * any other type is returned unchanged (fallback);
///  * a ranked tensor that already carries an encoding is kept as is;
///    otherwise it receives the default blocked encoding computed from the
///    configured warp count, threads per warp and CTA count;
///  * a Triton pointer whose pointee is a ranked tensor gets that pointee
///    converted, keeping the original address space.
///
/// Materializations:
///  * source (only when `enableSourceRemat` is set): an
///    `UnrealizedConversionCastOp` producing the requested tensor type;
///  * target: a `ConvertLayoutOp` producing the requested tensor type.
TritonGPUTypeConverter::TritonGPUTypeConverter(MLIRContext *context,
                                               int numWarps, int threadsPerWarp,
                                               int numCTAs,
                                               bool enableSourceRemat)
    : context(context), numWarps(numWarps), threadsPerWarp(threadsPerWarp),
      numCTAs(numCTAs) {
  // Fallback for every type not handled by a later-registered conversion.
  addConversion([](Type type) { return type; });

  // Unencoded tensors get the default blocked layout.
  addConversion([this](RankedTensorType tensorTy) -> RankedTensorType {
    if (!tensorTy.getEncoding()) {
      triton::gpu::BlockedEncodingAttr defaultLayout =
          getDefaultBlockedEncoding(this->context, tensorTy.getShape(),
                                    this->numWarps, this->threadsPerWarp,
                                    this->numCTAs);
      return tensorTy.cloneWithEncoding(defaultLayout);
    }
    return tensorTy;
  });

  // Pointers to tensors: convert the pointee, keep the address space.
  addConversion([this](triton::PointerType ptrTy) -> triton::PointerType {
    if (auto pointee = dyn_cast<RankedTensorType>(ptrTy.getPointeeType()))
      return triton::PointerType::get(convertType(pointee),
                                      ptrTy.getAddressSpace());
    return ptrTy;
  });

  if (enableSourceRemat) {
    addSourceMaterialization([](OpBuilder &builder, RankedTensorType resultTy,
                                ValueRange inputs, Location loc) -> Value {
      auto castOp =
          builder.create<UnrealizedConversionCastOp>(loc, resultTy, inputs);
      return castOp.getResult(0);
    });
  }

  addTargetMaterialization([](OpBuilder &builder, RankedTensorType resultTy,
                              ValueRange inputs, Location loc) {
    return builder.create<triton::gpu::ConvertLayoutOp>(loc, resultTy, inputs)
        .getResult();
  });
}
/// Conversion target for the Triton -> TritonGPU lowering.
///
///  * The TritonGPU dialect is always legal.
///  * scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp and
///    scf::ReduceReturnOp are always illegal.
///  * Ops from the arith, math, triton, cf, scf and ub dialects are legal
///    when `isDynamicallyLegal` accepts them under `typeConverter`.
///  * triton::DotOp is legal only when both its A and B operands carry a
///    DotOperandEncodingAttr.
///  * triton::FuncOp is legal only when every ranked-tensor argument has an
///    encoding.
///
/// NOTE: the dialect legality callback captures `typeConverter` by reference,
/// so the converter must outlive this target.
TritonGPUConversionTarget::TritonGPUConversionTarget(
    MLIRContext &context, TritonGPUTypeConverter &typeConverter)
    : ConversionTarget(context) {
  addLegalDialect<triton::gpu::TritonGPUDialect>();

  addIllegalOp<scf::ExecuteRegionOp, scf::ParallelOp, scf::ReduceOp,
               scf::ReduceReturnOp>();

  addDynamicallyLegalDialect<arith::ArithDialect, math::MathDialect,
                             triton::TritonDialect, cf::ControlFlowDialect,
                             scf::SCFDialect, ub::UBDialect>(
      [&](Operation *op) { return isDynamicallyLegal(op, typeConverter); });

  addDynamicallyLegalOp<triton::DotOp>([](triton::DotOp dotOp) -> bool {
    auto encodingOf = [](Value operand) -> Attribute {
      return cast<RankedTensorType>(operand.getType()).getEncoding();
    };
    auto isDotOperandLayout = [](Attribute enc) -> bool {
      return enc && isa<triton::gpu::DotOperandEncodingAttr>(enc);
    };
    // Both encodings are read up front, exactly like the checks they feed.
    Attribute lhsEnc = encodingOf(dotOp.getA());
    Attribute rhsEnc = encodingOf(dotOp.getB());
    return isDotOperandLayout(lhsEnc) && isDotOperandLayout(rhsEnc);
  });

  addDynamicallyLegalOp<triton::FuncOp>([](triton::FuncOp funcOp) -> bool {
    for (auto arg : funcOp.getArguments()) {
      auto tensorTy = dyn_cast<RankedTensorType>(arg.getType());
      // Any unencoded tensor argument makes the function illegal.
      if (tensorTy && !tensorTy.getEncoding())
        return false;
    }
    return true;
  });
}
/// Returns true when `op` is already legal under `typeConverter`.
///
/// Every region attached to `op` is checked first with
/// `TypeConverter::isLegal(Region *)`, then the op itself with
/// `TypeConverter::isLegal(Operation *)`. Evaluation stops at the first
/// illegal region, just as the original accumulate-with-&& loop did.
bool TritonGPUConversionTarget::isDynamicallyLegal(
    Operation *op, const TypeConverter &typeConverter) {
  bool regionsLegal = llvm::all_of(op->getRegions(), [&](Region &region) {
    return typeConverter.isLegal(&region);
  });
  return regionsLegal && typeConverter.isLegal(op);
}
/// Returns `type` re-encoded with the layout used for gather/scatter indices,
/// or a null RankedTensorType if `type` already has exactly that layout.
///
/// The target layout is the dim-0 slice of a 2-D blocked layout with
/// sizePerThread = {1, 4}, threadsPerWarp = {numThreads, 1},
/// warpsPerCTA = {1, numWarps}, order = {1, 0} and the default rank-2 CTA
/// layout. `type` must be rank 1 and carry a DistributedEncodingTrait
/// encoding (both enforced by assert/cast).
static RankedTensorType getNewIndicesType(RankedTensorType type,
                                          unsigned numThreads,
                                          unsigned numWarps) {
  assert(type.getRank() == 1);
  auto currentEnc = cast<DistributedEncodingTrait>(type.getEncoding());

  MLIRContext *ctx = type.getContext();
  std::array<unsigned, 2> elemsPerThread{1, 4};
  std::array<unsigned, 2> lanesPerWarp{numThreads, 1};
  std::array<unsigned, 2> warpsPerCta{1, numWarps};
  std::array<unsigned, 2> dimOrder{1, 0};
  auto defaultCtaLayout = CTALayoutAttr::getDefault(ctx, 2);

  auto blocked = BlockedEncodingAttr::get(ctx, elemsPerThread, lanesPerWarp,
                                          warpsPerCta, dimOrder,
                                          defaultCtaLayout);
  auto target = SliceEncodingAttr::get(ctx, /*dim=*/0, blocked);

  return currentEnc == target ? RankedTensorType()
                              : type.cloneWithEncoding(target);
}
/// Ensures the `indices` operand of gather/scatter op `op` uses the layout
/// chosen by `getNewIndicesType`.
///
/// If the layout differs, a `ConvertLayoutOp` is created at the rewriter's
/// current insertion point and `indices` is rewired to its result. If the
/// indices already have the required layout, nothing is created.
///
/// Both cases report success. The only caller runs this inside
/// `modifyOpInPlace`, after the op's operands and result types have been
/// rewritten. Reporting failure for the "already correct" case would make the
/// conversion pattern fail after it had mutated the op, even though nothing
/// is left to do.
static LogicalResult convertGatherScatterIndices(Operation *op,
                                                 OpOperand &indices,
                                                 ConversionPatternRewriter &b) {
  auto type = cast<RankedTensorType>(indices.get().getType());
  RankedTensorType newType =
      getNewIndicesType(type, lookupThreadsPerWarp(b), lookupNumWarps(op));
  // A null type means the indices are already in the required layout.
  if (!newType)
    return success();
  Value index = b.create<ConvertLayoutOp>(op->getLoc(), newType, indices.get());
  indices.set(index);
  return success();
}
/// Converts gather/scatter op `op` in place during the TritonGPU conversion.
///
/// Inside a single `modifyOpInPlace`, the op is updated in three steps:
///  1. its operands are replaced, pairwise, with the already-converted
///     `operands`;
///  2. each result type is converted with `typeConverter`;
///  3. the X offsets operand (`xOffsetsMutable`) is handled by
///     `convertGatherScatterIndices`.
///
/// Returns the status reported by `convertGatherScatterIndices`.
LogicalResult impl::convertGatherScatterOp(
    Operation *op, ValueRange operands, OpOperand &xOffsetsMutable,
    const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter) {
  LogicalResult status = success();
  rewriter.modifyOpInPlace(op, [&] {
    // Swap in the converted operand values.
    for (auto [opOperand, converted] : llvm::zip(op->getOpOperands(), operands))
      opOperand.set(converted);

    // Give every result its converted type.
    for (OpResult res : op->getOpResults()) {
      Type convertedTy = typeConverter.convertType(res.getType());
      res.setType(convertedTy);
    }

    // Re-lay-out the X offsets operand.
    status = convertGatherScatterIndices(op, xOffsetsMutable, rewriter);
  });
  return status;
}