#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
namespace mlir::triton {
#define GEN_PASS_DEF_CONVERTTRITONTOTRITONGPU
#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc"
}
namespace {
using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;
/// Merges the entries of `dictAttrs` into the attributes of `op`.
///
/// Only names that `op` does not already carry are copied; attributes that
/// are already present on `op` are never overwritten.
static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) {
  for (const NamedAttribute &entry : dictAttrs.getValue()) {
    StringAttr name = entry.getName();
    if (op->hasAttr(name))
      continue;
    op->setAttr(name, entry.getValue());
  }
}
/// Generic 1:1 conversion used for most ops.
///
/// The op is rebuilt with:
/// - its result types run through the type converter,
/// - the already-converted operands supplied by the adaptor,
/// - all attributes of the original op.
///
/// The pattern fails if any result type cannot be converted.
template <class Op> struct GenericOpPattern : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *converter = this->getTypeConverter();
    SmallVector<Type> convertedTypes;
    if (failed(converter->convertTypes(op->getResultTypes(), convertedTypes)))
      return failure();

    auto operands = adaptor.getOperands();
    auto attrs = op->getAttrs();
    rewriter.replaceOpWithNewOp<Op>(op, convertedTypes, operands, attrs);
    return success();
  }
};
/// Converts `arith.constant`.
///
/// The result type is run through the type converter; the converted type must
/// be a `ShapedType` (the `cast` asserts this). For ranked tensors, the dense
/// payload is re-typed to the converted type via `reshape`. Presumably the
/// converted type differs from the original only by its encoding (TODO
/// confirm).
class ArithConstantPattern : public OpConversionPattern<arith::ConstantOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedTy =
        cast<ShapedType>(getTypeConverter()->convertType(op.getType()));
    auto payload = dyn_cast<DenseElementsAttr>(adaptor.getValue());

    // A ranked-tensor constant must carry a dense payload whose type matches
    // the new (encoded) op type.
    if (isa<RankedTensorType>(convertedTy)) {
      assert(payload && "expected a dense elements attribute");
      payload = payload.reshape(convertedTy);
    }

    auto newConst = rewriter.replaceOpWithNewOp<arith::ConstantOp>(
        op, convertedTy, payload);
    addNamedAttrs(newConst, adaptor.getAttributes());
    return success();
  }
};
/// Registers the conversion patterns for Arith dialect ops.
///
/// `arith.constant` gets a dedicated pattern; every other listed op uses the
/// generic 1:1 pattern. `target` is accepted but not used by this function.
void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
                                      RewritePatternSet &patterns,
                                      TritonGPUConversionTarget &target) {
  MLIRContext *ctx = patterns.getContext();

  // Constants.
  patterns.add<ArithConstantPattern>(typeConverter, ctx);

  // Integer arithmetic.
  patterns.add<GenericOpPattern<arith::AddIOp>, GenericOpPattern<arith::SubIOp>,
               GenericOpPattern<arith::MulIOp>, GenericOpPattern<arith::DivUIOp>,
               GenericOpPattern<arith::DivSIOp>,
               GenericOpPattern<arith::CeilDivUIOp>,
               GenericOpPattern<arith::CeilDivSIOp>,
               GenericOpPattern<arith::FloorDivSIOp>,
               GenericOpPattern<arith::RemUIOp>,
               GenericOpPattern<arith::RemSIOp>>(typeConverter, ctx);

  // Bitwise logic and shifts.
  patterns.add<GenericOpPattern<arith::AndIOp>, GenericOpPattern<arith::OrIOp>,
               GenericOpPattern<arith::XOrIOp>, GenericOpPattern<arith::ShLIOp>,
               GenericOpPattern<arith::ShRUIOp>,
               GenericOpPattern<arith::ShRSIOp>>(typeConverter, ctx);

  // Floating-point arithmetic.
  patterns.add<GenericOpPattern<arith::AddFOp>, GenericOpPattern<arith::SubFOp>,
               GenericOpPattern<arith::MulFOp>, GenericOpPattern<arith::DivFOp>,
               GenericOpPattern<arith::RemFOp>>(typeConverter, ctx);

  // Minimum / maximum.
  patterns.add<GenericOpPattern<arith::MaximumFOp>,
               GenericOpPattern<arith::MaxNumFOp>,
               GenericOpPattern<arith::MaxSIOp>, GenericOpPattern<arith::MaxUIOp>,
               GenericOpPattern<arith::MinimumFOp>,
               GenericOpPattern<arith::MinNumFOp>,
               GenericOpPattern<arith::MinSIOp>,
               GenericOpPattern<arith::MinUIOp>>(typeConverter, ctx);

  // Comparisons and select.
  patterns.add<GenericOpPattern<arith::CmpIOp>, GenericOpPattern<arith::CmpFOp>,
               GenericOpPattern<arith::SelectOp>>(typeConverter, ctx);

  // Casts.
  patterns.add<GenericOpPattern<arith::TruncIOp>,
               GenericOpPattern<arith::TruncFOp>,
               GenericOpPattern<arith::ExtUIOp>, GenericOpPattern<arith::ExtSIOp>,
               GenericOpPattern<arith::ExtFOp>, GenericOpPattern<arith::SIToFPOp>,
               GenericOpPattern<arith::FPToSIOp>,
               GenericOpPattern<arith::FPToUIOp>,
               GenericOpPattern<arith::UIToFPOp>>(typeConverter, ctx);
}
/// Registers the generic 1:1 conversion pattern for each supported Math
/// dialect op. `target` is accepted but not used by this function.
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
                                     RewritePatternSet &patterns,
                                     TritonGPUConversionTarget &target) {
  MLIRContext *ctx = patterns.getContext();

  // Exponential, logarithmic, trigonometric and error-function ops.
  patterns.add<GenericOpPattern<math::ExpOp>, GenericOpPattern<math::Exp2Op>,
               GenericOpPattern<math::LogOp>, GenericOpPattern<math::Log2Op>,
               GenericOpPattern<math::CosOp>, GenericOpPattern<math::SinOp>,
               GenericOpPattern<math::ErfOp>>(typeConverter, ctx);

  // Rounding, absolute value, roots and fused multiply-add.
  patterns.add<GenericOpPattern<math::FloorOp>, GenericOpPattern<math::CeilOp>,
               GenericOpPattern<math::AbsFOp>, GenericOpPattern<math::AbsIOp>,
               GenericOpPattern<math::SqrtOp>, GenericOpPattern<math::RsqrtOp>,
               GenericOpPattern<math::FmaOp>>(typeConverter, ctx);
}
/// Converts `tt.expand_dims`.
///
/// Requirements on the converted source:
/// - If it has no encoding, the pattern fails.
/// - Otherwise the encoding must be a `BlockedEncodingAttr` (the `cast` below
///   asserts this).
///
/// Steps:
/// 1. Build a blocked encoding for the result by inserting a 1 at `axis` into
///    the source's sizePerThread / threadsPerWarp / warpsPerCTA, and into its
///    CTA layout.
/// 2. Convert the source to a `SliceEncodingAttr` of that result encoding
///    along `axis`.
/// 3. Recreate expand_dims on the converted source. No result type is passed
///    to the builder.
struct TritonExpandDimsPattern
    : public OpConversionPattern<triton::ExpandDimsOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(triton::ExpandDimsOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    RankedTensorType argType =
        cast<RankedTensorType>(adaptor.getSrc().getType());
    Attribute _argEncoding = argType.getEncoding();
    if (!_argEncoding)
      return failure();
    auto argEncoding = cast<triton::gpu::BlockedEncodingAttr>(_argEncoding);
    // Result shape: the source shape with a new unit dimension at `axis`.
    auto retShape = argType.getShape().vec();
    retShape.insert(retShape.begin() + op.getAxis(), 1);
    // Per-dimension layout vectors of the result: copy the source's vectors
    // and insert a 1 at `axis`.
    auto retSizePerThread = llvm::to_vector(argEncoding.getSizePerThread());
    retSizePerThread.insert(retSizePerThread.begin() + op.getAxis(), 1);
    auto retThreadsPerWarp = to_vector(argEncoding.getThreadsPerWarp());
    retThreadsPerWarp.insert(retThreadsPerWarp.begin() + op.getAxis(), 1);
    auto retWarpsPerCTA = to_vector(argEncoding.getWarpsPerCTA());
    retWarpsPerCTA.insert(retWarpsPerCTA.begin() + op.getAxis(), 1);
    // NOTE(review): the result order is the identity [0, 1, ..., rank-1];
    // the source encoding's order is not propagated. Confirm this is intended.
    SmallVector<unsigned, 4> retOrder(retShape.size());
    std::iota(retOrder.begin(), retOrder.end(), 0);
    // CTA layout: unit entries at `axis` for CTAsPerCGA / CTASplitNum, and
    // the new dimension prepended to the CTA order (see insertOrder).
    auto argCTALayout = argEncoding.getCTALayout();
    auto retCTAsPerCGA = insertOne(argCTALayout.getCTAsPerCGA(), op.getAxis());
    auto retCTASplitNum =
        insertOne(argCTALayout.getCTASplitNum(), op.getAxis());
    auto retCTAOrder = insertOrder(argCTALayout.getCTAOrder(), op.getAxis());
    auto retCTALayout = triton::gpu::CTALayoutAttr::get(
        getContext(), retCTAsPerCGA, retCTASplitNum, retCTAOrder);
    triton::gpu::BlockedEncodingAttr retEncoding =
        triton::gpu::BlockedEncodingAttr::get(getContext(), retSizePerThread,
                                              retThreadsPerWarp, retWarpsPerCTA,
                                              retOrder, retCTALayout);
    // The source is expressed as a slice of the result encoding along `axis`,
    // then converted to that layout before the new expand_dims is built.
    Attribute newArgEncoding = triton::gpu::SliceEncodingAttr::get(
        getContext(), op.getAxis(), retEncoding);
    RankedTensorType newArgType = argType.cloneWithEncoding(newArgEncoding);
    auto newSrc = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op.getLoc(), newArgType, adaptor.getSrc());
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(
                      op, newSrc, adaptor.getAxis()),
                  adaptor.getAttributes());
    return success();
  }
private:
  /// Returns a copy of `vec` with the value 1 inserted at position `axis`.
  template <typename T>
  SmallVector<T> insertOne(ArrayRef<T> vec, unsigned axis) const {
    SmallVector<T> res(vec.begin(), vec.end());
    res.insert(res.begin() + axis, 1);
    return res;
  }
  /// Returns `order` with every dimension index >= `axis` shifted up by one
  /// and `axis` itself prepended.
  /// Example: order = [0, 2, 1, 3], axis = 2  ->  [2, 0, 3, 1, 4].
  SmallVector<unsigned> insertOrder(ArrayRef<unsigned> order,
                                    unsigned axis) const {
    SmallVector<unsigned> resOrder(order.begin(), order.end());
    for (unsigned i = 0; i < resOrder.size(); ++i)
      if (resOrder[i] >= axis)
        ++resOrder[i];
    resOrder.insert(resOrder.begin(), axis);
    return resOrder;
  }
};
/// Converts `tt.dot`.
///
/// Steps:
/// 1. Choose a blocked encoding (`dEncoding`) for the result.
/// 2. Convert `a` and `b` to `DotOperandEncodingAttr`s (operand index 0 and 1)
///    whose parent is `dEncoding`, unless they already have a dot-operand
///    encoding.
/// 3. Always convert the accumulator `c` to the result type.
///
/// The pattern fails if `a` or `b` has no encoding. It assumes the dot has
/// rank >= 2, since indices rank-1 and rank-2 are used unconditionally.
struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    RankedTensorType origType = op.getType();
    auto origShape = origType.getShape();
    auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();
    int numWarps = typeConverter->getNumWarps();
    int threadsPerWarp = typeConverter->getThreadsPerWarp();
    int numCTAs = typeConverter->getNumCTAs();
    auto rank = origShape.size();
    // sizePerThread for the result is picked from
    // numElements / (numWarps * threadsPerWarp):
    //   - >= 4 gives a 2x2 tile on the two innermost dimensions,
    //   - >= 16 gives a 4x4 tile,
    //   - otherwise every dimension stays at 1.
    SmallVector<unsigned> retSizePerThread(rank, 1);
    auto numElements = product<int64_t>(origShape);
    if (numElements / (numWarps * threadsPerWarp) >= 4) {
      retSizePerThread[rank - 1] = 2;
      retSizePerThread[rank - 2] = 2;
    }
    if (numElements / (numWarps * threadsPerWarp) >= 16) {
      retSizePerThread[rank - 1] = 4;
      retSizePerThread[rank - 2] = 4;
    }
    // Never exceed the actual extent of those two dimensions.
    retSizePerThread[rank - 1] = std::min(
        retSizePerThread[rank - 1], static_cast<unsigned>(origShape[rank - 1]));
    retSizePerThread[rank - 2] = std::min(
        retSizePerThread[rank - 2], static_cast<unsigned>(origShape[rank - 2]));
    // Order is [rank-1, ..., 1, 0]: the last dimension is listed first.
    SmallVector<unsigned> retOrder(rank);
    for (unsigned i = 0; i < rank; ++i)
      retOrder[i] = rank - 1 - i;
    Attribute dEncoding = triton::gpu::BlockedEncodingAttr::get(
        getContext(), origShape, retSizePerThread, retOrder, numWarps,
        threadsPerWarp, numCTAs);
    RankedTensorType retType = origType.cloneWithEncoding(dEncoding);
    auto aType = cast<RankedTensorType>(adaptor.getA().getType());
    auto bType = cast<RankedTensorType>(adaptor.getB().getType());
    Type aEltType = aType.getElementType();
    Type bEltType = bType.getElementType();
    Attribute aEncoding = aType.getEncoding();
    Attribute bEncoding = bType.getEncoding();
    if (!aEncoding || !bEncoding)
      return failure();
    Value a = adaptor.getA();
    Value b = adaptor.getB();
    Value c = adaptor.getC();
    // Operands that already have a dot-operand encoding are kept as-is.
    if (!mlir::isa<triton::gpu::DotOperandEncodingAttr>(aEncoding)) {
      Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
          getContext(), 0, dEncoding, aEltType);
      auto dstType = aType.cloneWithEncoding(encoding);
      a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
    }
    if (!mlir::isa<triton::gpu::DotOperandEncodingAttr>(bEncoding)) {
      Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
          getContext(), 1, dEncoding, bEltType);
      auto dstType = bType.cloneWithEncoding(encoding);
      b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
    }
    // The accumulator always gets the result layout.
    c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::DotOp>(
                      op, retType, a, b, c, adaptor.getInputPrecision(),
                      adaptor.getMaxNumImpreciseAcc()),
                  adaptor.getAttributes());
    return success();
  }
};
/// Converts `tt.cat`.
///
/// Starts from the type converter's encoding for the result, which must be a
/// `BlockedEncodingAttr` (the `cast` asserts this). Its sizePerThread along
/// dimension order[0] is multiplied by
///   nextPowOf2(lhsElemsPerThread + rhsElemsPerThread) / retElemsPerThread.
/// The op is then rebuilt with that widened encoding.
struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto defaultTy = cast<RankedTensorType>(
        this->getTypeConverter()->convertType(op.getType()));
    auto defaultEnc =
        cast<triton::gpu::BlockedEncodingAttr>(defaultTy.getEncoding());

    // Elements per thread held by both inputs together vs. the default result.
    auto inputElems =
        triton::gpu::getTotalElemsPerThread(adaptor.getLhs().getType()) +
        triton::gpu::getTotalElemsPerThread(adaptor.getRhs().getType());
    auto defaultElems = triton::gpu::getTotalElemsPerThread(defaultTy);

    // Widen dimension order[0] by the ratio between the input element count
    // (rounded up to a power of two) and the default per-thread count.
    auto order = defaultEnc.getOrder();
    auto sizePerThread = llvm::to_vector(defaultEnc.getSizePerThread());
    sizePerThread[order[0]] *= nextPowOf2(inputElems) / defaultElems;

    auto catEnc = triton::gpu::BlockedEncodingAttr::get(
        getContext(), sizePerThread, defaultEnc.getThreadsPerWarp(),
        defaultEnc.getWarpsPerCTA(), order, defaultEnc.getCTALayout());
    auto catTy = defaultTy.cloneWithEncoding(catEnc);

    auto newCat = rewriter.replaceOpWithNewOp<triton::CatOp>(
        op, catTy, adaptor.getOperands());
    addNamedAttrs(newCat, adaptor.getAttributes());
    return success();
  }
};
/// Converts `tt.join` by rebuilding it from the converted operands.
///
/// No result type is passed to the builder, so the result type is presumably
/// inferred from the operands. This differs from `GenericOpPattern`, which
/// uses the type converter's result type.
struct TritonJoinOpPattern : public OpConversionPattern<triton::JoinOp> {
  using OpConversionPattern::OpConversionPattern;

  // `override` makes the compiler reject this method if it ever stops
  // matching the base-class virtual signature.
  LogicalResult
  matchAndRewrite(JoinOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::JoinOp>(
                      op, adaptor.getLhs(), adaptor.getRhs()),
                  adaptor.getAttributes());
    return success();
  }
};
/// Converts `tt.split`.
///
/// The operand must have a blocked encoding that satisfies both:
/// - sizePerThread == 2 in the last dimension, and
/// - the last dimension listed first in `order` (order.front() == rank - 1).
///
/// If the converted operand does not satisfy this, including when it has no
/// blocked encoding at all, a `ConvertLayoutOp` to a suitable encoding is
/// inserted first. The split is then recreated; no result type is passed to
/// its builder.
struct TritonSplitOpPattern : public OpConversionPattern<triton::SplitOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(SplitOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto src = adaptor.getSrc();
    auto srcTy = cast<RankedTensorType>(src.getType());
    // The encoding may be missing or non-blocked. Both cases take the
    // fallback branch below, so use a cast that tolerates a null input.
    auto srcEnc =
        llvm::dyn_cast_if_present<BlockedEncodingAttr>(srcTy.getEncoding());
    // Read the rank from the tensor type, not from `srcEnc`: `srcEnc` may be
    // null here, and it must not be dereferenced before the check below.
    int rank = srcTy.getRank();
    auto typeConverter = getTypeConverter<TritonGPUTypeConverter>();

    if (!srcEnc || srcEnc.getSizePerThread().back() != 2 ||
        srcEnc.getOrder().front() != rank - 1) {
      // Start from the default encoding for the op's first result (the
      // post-split shape). Then extend it by one trailing dimension with
      // sizePerThread 2, unit threads/warps/CTAs, listed first in both orders.
      auto defaultEnc = getDefaultBlockedEncoding(
          getContext(),
          cast<RankedTensorType>(op.getResult(0).getType()).getShape(),
          typeConverter->getNumWarps(), typeConverter->getThreadsPerWarp(),
          typeConverter->getNumCTAs());

      auto append = [](ArrayRef<unsigned> vals, unsigned val) {
        SmallVector<unsigned> res(vals);
        res.push_back(val);
        return res;
      };
      auto prepend = [](ArrayRef<unsigned> vals, unsigned val) {
        SmallVector<unsigned> res;
        res.push_back(val);
        res.append(vals.begin(), vals.end());
        return res;
      };

      srcEnc = BlockedEncodingAttr::get(
          getContext(), append(defaultEnc.getSizePerThread(), 2),
          append(defaultEnc.getThreadsPerWarp(), 1),
          append(defaultEnc.getWarpsPerCTA(), 1),
          prepend(defaultEnc.getOrder(), rank - 1),
          CTALayoutAttr::get(getContext(),
                             append(defaultEnc.getCTAsPerCGA(), 1),
                             append(defaultEnc.getCTASplitNum(), 1),
                             prepend(defaultEnc.getCTAOrder(), rank - 1)));
      srcTy = srcTy.cloneWithEncoding(srcEnc);
      src = rewriter.create<ConvertLayoutOp>(op.getLoc(), srcTy, src);
    }

    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::SplitOp>(op, src),
                  adaptor.getAttributes());
    return success();
  }
};
/// Converts `tt.trans` by rebuilding it on the converted source with the same
/// permutation.
///
/// No result type is passed to the builder. The pattern fails if the
/// converted source tensor has no encoding.
struct TritonTransPattern : public OpConversionPattern<TransOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(TransOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value input = adaptor.getSrc();
    if (!cast<RankedTensorType>(input.getType()).getEncoding())
      return failure();

    auto newTrans =
        rewriter.replaceOpWithNewOp<TransOp>(op, input, op.getOrder());
    addNamedAttrs(newTrans, adaptor.getAttributes());
    return success();
  }
};
/// Converts `tt.broadcast`.
///
/// The result keeps the original result shape but takes the encoding of the
/// converted source. The pattern fails if that source has no encoding.
struct TritonBroadcastPattern
    : public OpConversionPattern<triton::BroadcastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(BroadcastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Attribute encoding =
        cast<RankedTensorType>(adaptor.getSrc().getType()).getEncoding();
    if (!encoding)
      return failure();

    Type resultTy = op.getType().cloneWithEncoding(encoding);
    auto newBroadcast = rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
        op, resultTy, adaptor.getOperands());
    addNamedAttrs(newBroadcast, adaptor.getAttributes());
    return success();
  }
};
/// Converts `tt.reduce`.
///
/// Steps:
/// 1. Create a new reduce over the converted operands with the same axis.
/// 2. Forward attributes the new op lacks.
/// 3. Clone the combine region into it.
/// 4. Replace the old op with the new results.
struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto reduced = rewriter.create<triton::ReduceOp>(
        loc, adaptor.getOperands(), adaptor.getAxis());
    addNamedAttrs(reduced, adaptor.getAttributes());

    Region &dstCombine = reduced.getCombineOp();
    rewriter.cloneRegionBefore(op.getCombineOp(), dstCombine,
                               dstCombine.end());
    rewriter.replaceOp(op, reduced.getResult());
    return success();
  }
};
/// Converts `tt.scan`.
///
/// Steps:
/// 1. Create a new scan over the converted operands with the same axis and
///    `reverse` flag.
/// 2. Forward attributes the new op lacks.
/// 3. Clone the combine region into it.
/// 4. Replace the old op with the new results.
struct TritonScanPattern : public OpConversionPattern<triton::ScanOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::ScanOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto scanned = rewriter.create<triton::ScanOp>(
        loc, adaptor.getOperands(), adaptor.getAxis(), op.getReverse());
    addNamedAttrs(scanned, adaptor.getAttributes());

    Region &dstCombine = scanned.getCombineOp();
    rewriter.cloneRegionBefore(op.getCombineOp(), dstCombine,
                               dstCombine.end());
    rewriter.replaceOp(op, scanned.getResult());
    return success();
  }
};
/// Converts `tt.map_elementwise`.
///
/// Result types go through the type converter; the pattern fails if any
/// result type cannot be converted. The op is recreated with the converted
/// operands and the same `pack` value, and the scalar region is cloned into
/// the new op.
struct TritonMapElementwisePattern
    : public OpConversionPattern<triton::MapElementwiseOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::MapElementwiseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Type> convertedTypes;
    if (failed(getTypeConverter()->convertTypes(op.getResults().getType(),
                                                convertedTypes)))
      return failure();

    auto mapped = rewriter.create<triton::MapElementwiseOp>(
        op.getLoc(), convertedTypes, adaptor.getOperands(), op.getPack());
    addNamedAttrs(mapped, adaptor.getAttributes());

    Region &dstBody = mapped.getScalarOp();
    rewriter.cloneRegionBefore(op.getScalarOp(), dstBody, dstBody.end());
    rewriter.replaceOp(op, mapped.getResult());
    return success();
  }
};
/// Converts `tt.func`.
///
/// Steps:
/// 1. Recreate the function with the same name and the original function
///    type.
/// 2. Forward attributes the new op lacks.
/// 3. Move the body into the new function.
/// 4. If the body is non-empty, apply a signature conversion to its entry
///    block only.
///
/// NOTE(review): the SignatureConversion is applied without any input
/// mappings having been added to it. Confirm this yields the intended
/// argument types.
class TritonFuncOpPattern : public OpConversionPattern<triton::FuncOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    TypeConverter::SignatureConversion entryConversion(op.getNumArguments());

    auto newFunc = rewriter.replaceOpWithNewOp<triton::FuncOp>(
        op, op.getName(), op.getFunctionType());
    addNamedAttrs(newFunc, adaptor.getAttributes());

    Region &newBody = newFunc.getBody();
    rewriter.inlineRegionBefore(op.getBody(), newBody, newBody.end());
    if (newBody.empty())
      return success();

    rewriter.applySignatureConversion(&newBody.front(), entryConversion,
                                      getTypeConverter());
    return success();
  }
};
/// Rebuilds `tt.call` with the converted operands.
///
/// The callee and the original (unconverted) result types are kept.
/// NOTE(review): `populateTritonPatterns` registers
/// `GenericOpPattern<triton::CallOp>` rather than this pattern.
class TritonCallOpPattern : public OpConversionPattern<triton::CallOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::CallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto callee = op.getCallee();
    auto resultTypes = op.getResultTypes();
    auto newCall = rewriter.replaceOpWithNewOp<triton::CallOp>(
        op, callee, resultTypes, adaptor.getOperands());
    addNamedAttrs(newCall, adaptor.getAttributes());
    return success();
  }
};
/// Rebuilds a return op with the converted operands. No attributes are
/// forwarded.
///
/// NOTE(review): `populateTritonPatterns` registers
/// `GenericOpPattern<ReturnOp>` rather than this pattern.
class TritonReturnOpPattern : public OpConversionPattern<ReturnOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ReturnOp op, ReturnOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto convertedOperands = adaptor.getOperands();
    rewriter.replaceOpWithNewOp<ReturnOp>(op, convertedOperands);
    return success();
  }
};
/// Registers the conversion patterns for Triton dialect ops.
///
/// `numCTAs` is accepted but not used by this function.
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
                            RewritePatternSet &patterns, unsigned numCTAs) {
  MLIRContext *ctx = patterns.getContext();

  // Pointer, shape and cast ops handled by the generic 1:1 pattern.
  patterns.insert<GenericOpPattern<triton::AdvanceOp>,
                  GenericOpPattern<triton::MakeTensorPtrOp>,
                  GenericOpPattern<triton::ReshapeOp>,
                  GenericOpPattern<triton::BitcastOp>,
                  GenericOpPattern<triton::FpToFpOp>,
                  GenericOpPattern<triton::IntToPtrOp>,
                  GenericOpPattern<triton::PtrToIntOp>,
                  GenericOpPattern<triton::SplatOp>,
                  GenericOpPattern<triton::UnsplatOp>,
                  GenericOpPattern<triton::AddPtrOp>,
                  GenericOpPattern<triton::MakeRangeOp>>(typeConverter, ctx);

  // Ops with dedicated, layout-aware patterns.
  patterns.insert<TritonBroadcastPattern, TritonCatPattern,
                  TritonJoinOpPattern, TritonSplitOpPattern,
                  TritonExpandDimsPattern, TritonTransPattern,
                  TritonDotPattern, TritonMapElementwisePattern>(typeConverter,
                                                                 ctx);

  // Elementwise ops.
  patterns.insert<GenericOpPattern<triton::ClampFOp>,
                  GenericOpPattern<triton::PreciseSqrtOp>,
                  GenericOpPattern<triton::PreciseDivFOp>,
                  GenericOpPattern<triton::MulhiUIOp>,
                  GenericOpPattern<triton::ElementwiseInlineAsmOp>,
                  GenericOpPattern<triton::ExternElementwiseOp>>(typeConverter,
                                                                 ctx);

  // Region-carrying reductions/scans and their terminators.
  patterns.insert<TritonReducePattern, GenericOpPattern<triton::ReduceReturnOp>,
                  TritonScanPattern, GenericOpPattern<triton::ScanReturnOp>>(
      typeConverter, ctx);

  // Memory, descriptor and atomic ops, plus scaled dot.
  patterns.insert<GatherScatterOpPattern<DescriptorGatherOp>,
                  GatherScatterOpPattern<DescriptorScatterOp>,
                  GenericOpPattern<triton::LoadOp>,
                  GenericOpPattern<triton::StoreOp>,
                  GenericOpPattern<triton::HistogramOp>,
                  GenericOpPattern<triton::GatherOp>,
                  GenericOpPattern<triton::AtomicCASOp>,
                  GenericOpPattern<triton::AtomicRMWOp>,
                  GenericOpPattern<triton::DescriptorLoadOp>,
                  GenericOpPattern<triton::DescriptorStoreOp>,
                  GenericOpPattern<triton::DescriptorReduceOp>,
                  GenericOpPattern<triton::DotScaledOp>>(typeConverter, ctx);

  // Debugging ops.
  patterns.insert<GenericOpPattern<triton::PrintOp>,
                  GenericOpPattern<triton::AssertOp>>(typeConverter, ctx);

  // Functions, calls and returns.
  patterns.insert<GenericOpPattern<triton::CallOp>, GenericOpPattern<ReturnOp>,
                  TritonFuncOpPattern>(typeConverter, ctx);
}
/// Converts `scf.for`.
///
/// Steps:
/// 1. Convert every result type; each must map 1:1.
/// 2. Clone the loop without its region and move the original body into the
///    clone.
/// 3. Convert the body's block-argument types.
/// 4. Swap in the converted operands and result types.
struct SCFForPattern : public OpConversionPattern<scf::ForOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(scf::ForOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Convert the result types before touching the IR. An unconvertible type
    // then makes the pattern fail without leaving a half-built loop behind;
    // SCFIfPattern and SCFWhilePattern follow the same order.
    SmallVector<Type> newResultTypes;
    for (Type type : op.getResultTypes()) {
      Type newType = typeConverter->convertType(type);
      if (!newType)
        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
      newResultTypes.push_back(newType);
    }

    auto newOp =
        cast<scf::ForOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
    rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
                                newOp.getRegion().end());
    // Rewrite the block signatures inside the moved body.
    if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
                                           *getTypeConverter())))
      return rewriter.notifyMatchFailure(op, "could not convert body types");

    // The clone still uses the original operands; switch to converted ones.
    newOp->setOperands(adaptor.getOperands());
    for (auto [result, type] : llvm::zip(newOp.getResults(), newResultTypes))
      result.setType(type);

    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};
/// Converts `scf.if`.
///
/// Steps:
/// 1. Convert every result type, failing unless each maps 1:1.
/// 2. Clone the op without regions and move the then/else regions into the
///    clone.
/// 3. Swap in the converted operands and result types.
class SCFIfPattern : public OpConversionPattern<scf::IfOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(scf::IfOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<Type> convertedTypes;
    convertedTypes.reserve(op.getNumResults());
    for (Type resultTy : op.getResultTypes()) {
      Type converted = typeConverter->convertType(resultTy);
      if (!converted)
        return rewriter.notifyMatchFailure(op, "not a 1:1 type conversion");
      convertedTypes.push_back(converted);
    }

    auto newIf =
        cast<scf::IfOp>(rewriter.cloneWithoutRegions(*op.getOperation()));
    Region &thenDst = newIf.getThenRegion();
    rewriter.inlineRegionBefore(op.getThenRegion(), thenDst, thenDst.end());
    Region &elseDst = newIf.getElseRegion();
    rewriter.inlineRegionBefore(op.getElseRegion(), elseDst, elseDst.end());
    newIf->setOperands(adaptor.getOperands());

    for (unsigned idx = 0, e = convertedTypes.size(); idx < e; ++idx)
      newIf->getResult(idx).setType(convertedTypes[idx]);

    rewriter.replaceOp(op, newIf.getResults());
    return success();
  }
};
/// Converts `scf.while`.
///
/// Steps:
/// 1. Convert the result types.
/// 2. Build a new while op over the converted operands.
/// 3. For each of its two regions (indices 0 and 1), move the original region
///    in and convert its block-argument types.
class SCFWhilePattern : public OpConversionPattern<scf::WhileOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *converter = getTypeConverter();
    assert(converter);

    SmallVector<Type> resultTypes;
    if (failed(converter->convertTypes(op.getResultTypes(), resultTypes)))
      return failure();

    auto newWhile = rewriter.create<scf::WhileOp>(op.getLoc(), resultTypes,
                                                  adaptor.getOperands());
    for (unsigned idx = 0; idx < 2; ++idx) {
      Region &dst = newWhile.getRegion(idx);
      rewriter.inlineRegionBefore(op.getRegion(idx), dst, dst.end());
      if (failed(rewriter.convertRegionTypes(&dst, *converter)))
        return rewriter.notifyMatchFailure(op, "could not convert body types");
    }

    rewriter.replaceOp(op, newWhile.getResults());
    return success();
  }
};
/// Converts `scf.condition` in place by swapping its operands for the
/// converted ones.
class SCFConditionPattern : public OpConversionPattern<scf::ConditionOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(scf::ConditionOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto newOperands = adaptor.getOperands();
    rewriter.modifyOpInPlace(op, [&] { op->setOperands(newOperands); });
    return success();
  }
};
/// Registers the conversion patterns for the SCF ops handled by this pass.
void populateSCFPatterns(TritonGPUTypeConverter &typeConverter,
                         RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  // `scf.yield` only needs its operands and result types converted.
  patterns.add<GenericOpPattern<scf::YieldOp>>(typeConverter, ctx);
  // Region-carrying ops and the while-loop condition terminator.
  patterns.add<SCFForPattern, SCFIfPattern, SCFWhilePattern,
               SCFConditionPattern>(typeConverter, ctx);
}
/// Converts `cf.br`.
///
/// Steps:
/// 1. Rebuild the branch with the converted operands.
/// 2. Forward attributes the new op lacks.
/// 3. Convert the block signatures of the region containing the successor.
class CFBranchPattern : public OpConversionPattern<cf::BranchOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto converter = getTypeConverter();
    auto newOp = rewriter.replaceOpWithNewOp<cf::BranchOp>(
        op, op.getSuccessor(), adaptor.getOperands());
    // Keep the original op's attributes, as CFCondBranchPattern and the other
    // patterns in this pass do; previously they were silently dropped.
    addNamedAttrs(newOp, adaptor.getAttributes());
    if (failed(rewriter.convertRegionTypes(newOp.getSuccessor()->getParent(),
                                           *converter)))
      return failure();
    return success();
  }
};
/// Converts `cf.cond_br`.
///
/// Steps:
/// 1. Rebuild the branch with the converted condition and successor operands.
/// 2. Forward attributes the new op lacks.
/// 3. Convert the block signatures of the true destination's region, then of
///    the false destination's region.
class CFCondBranchPattern : public OpConversionPattern<cf::CondBranchOp> {
public:
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(cf::CondBranchOp op, cf::CondBranchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto *converter = getTypeConverter();
    auto newBranch = rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
        op, adaptor.getCondition(), op.getTrueDest(),
        adaptor.getTrueDestOperands(), op.getFalseDest(),
        adaptor.getFalseDestOperands());
    addNamedAttrs(newBranch, adaptor.getAttributes());

    Region *trueRegion = newBranch.getTrueDest()->getParent();
    if (failed(rewriter.convertRegionTypes(trueRegion, *converter)))
      return failure();

    // Queried only after the first conversion, as before.
    Region *falseRegion = newBranch.getFalseDest()->getParent();
    return failed(rewriter.convertRegionTypes(falseRegion, *converter))
               ? failure()
               : success();
  }
};
/// Registers the conversion patterns for the ControlFlow branch ops.
void populateCFPatterns(TritonGPUTypeConverter &typeConverter,
                        RewritePatternSet &patterns) {
  patterns.add<CFCondBranchPattern, CFBranchPattern>(typeConverter,
                                                     patterns.getContext());
}
/// Pass converting a module from the Triton dialect to TritonGPU.
///
/// Behavior:
/// - Requires the `target` option to be non-empty.
/// - Records numWarps, threadsPerWarp, numCTAs and the target string as module
///   attributes.
/// - Runs a partial dialect conversion with the patterns registered by the
///   populate* helpers.
class ConvertTritonToTritonGPU
    : public triton::impl::ConvertTritonToTritonGPUBase<
          ConvertTritonToTritonGPU> {
public:
  using ConvertTritonToTritonGPUBase::ConvertTritonToTritonGPUBase;

  void runOnOperation() override {
    ModuleOp mod = getOperation();
    // `target` is the pass option holding the target string.
    if (target.getValue().empty()) {
      mlir::emitError(
          mod.getLoc(),
          "'convert-triton-to-tritongpu' requires 'target' option to be set");
      return signalPassFailure();
    }

    MLIRContext *ctx = &getContext();
    TritonGPUTypeConverter typeConverter(ctx, numWarps, threadsPerWarp,
                                         numCTAs, enableSourceRemat);
    // Named so that it does not shadow the `target` pass option.
    TritonGPUConversionTarget conversionTarget(*ctx, typeConverter);

    RewritePatternSet patterns(ctx);
    populateArithPatternsAndLegality(typeConverter, patterns, conversionTarget);
    populateMathPatternsAndLegality(typeConverter, patterns, conversionTarget);
    populateTritonPatterns(typeConverter, patterns, numCTAs);
    populateSCFPatterns(typeConverter, patterns);
    populateCFPatterns(typeConverter, patterns);
    patterns.insert<GenericOpPattern<ub::PoisonOp>>(typeConverter, ctx);

    Builder builder(ctx);
    mod->setAttr(AttrNumWarpsName, builder.getI32IntegerAttr(numWarps));
    mod->setAttr(AttrNumThreadsPerWarp,
                 builder.getI32IntegerAttr(threadsPerWarp));
    mod->setAttr(AttrNumCTAsName, builder.getI32IntegerAttr(numCTAs));
    mod->setAttr(AttrTargetName,
                 builder.getStringAttr(this->target.getValue()));

    if (failed(applyPartialConversion(mod, conversionTarget,
                                      std::move(patterns))))
      return signalPassFailure();
  }
};
}