#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
namespace {
using namespace mlir;
using namespace mlir::triton;

/// Conversion pattern that lowers `triton::GetProgramIdOp`.
///
/// The pattern does no code generation of its own. It collects the op's
/// location, its enclosing `ModuleOp` and its axis, and hands them to
/// `TargetInfoBase::programId`. The op is then replaced by the value that hook
/// returns.
struct GetProgramIdOpConversion
    : public ConvertOpToLLVMPattern<triton::GetProgramIdOp> {
  /// @param typeConverter  converter forwarded to the base conversion pattern.
  /// @param targetInfo     hooks used to materialize the program id. Only a
  ///                       reference is kept, so it must outlive the pattern.
  /// @param benefit        pattern benefit forwarded to the base pattern.
  explicit GetProgramIdOpConversion(LLVMTypeConverter &typeConverter,
                                    const TargetInfoBase &targetInfo,
                                    PatternBenefit benefit = 1)
      : ConvertOpToLLVMPattern<triton::GetProgramIdOp>(typeConverter, benefit),
        targetInfo(targetInfo) {}

  LogicalResult
  matchAndRewrite(triton::GetProgramIdOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Gather the inputs the target hook needs.
    Location loc = op->getLoc();
    ModuleOp enclosingModule = op->getParentOfType<ModuleOp>();

    // Let the target build the program-id value, then substitute it for the op.
    Value pid =
        targetInfo.programId(rewriter, loc, enclosingModule, op.getAxis());
    rewriter.replaceOp(op, pid);
    return success();
  }

private:
  // Non-owning; supplied by the caller at construction time.
  const TargetInfoBase &targetInfo;
};
} // namespace
/// Registers this file's SPMD-op lowering pattern (the
/// `triton::GetProgramIdOp` conversion) into `patterns`.
///
/// @param typeConverter  type converter the registered pattern will use.
/// @param patterns       pattern set that receives the new pattern.
/// @param targetInfo     target hooks. The pattern keeps only a reference, so
///                       this must stay alive while the pattern may be applied.
/// @param benefit        benefit assigned to the registered pattern.
void mlir::triton::populateSPMDOpToLLVMPattern(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    const TargetInfoBase &targetInfo, PatternBenefit benefit) {
  patterns.add<GetProgramIdOpConversion>(typeConverter, targetInfo, benefit);
}