#include "Dialect/TritonAMDGPU/IR/Dialect.h"
#include "TritonAMDGPUToLLVM/GCNAsmFormat.h"
#include "Utility.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "third_party/amd/include/Utils/Utility.h"
#include "triton/Analysis/Utility.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
using namespace mlir;
using namespace mlir::triton;
namespace {
// Lowers amdgpu::ExtractSliceOp to LLVM by re-selecting per-thread register
// values: for every register of the result layout, find the source register
// that holds the same tensor element and forward its value.
struct ExtractSliceOpConversion
    : public ConvertOpToLLVMPattern<amdgpu::ExtractSliceOp> {
  using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;

  // Builds the destination register values from the source register values.
  //
  // Source and destination layouts are converted to linear layouts. Each
  // register index is pushed through its layout with every other input
  // dimension held at 0. The resulting element coordinates are then matched,
  // after shifting the destination coordinates by the slice's static offsets.
  // Only the "register" input dimension is enumerated, so this emits no
  // cross-lane/cross-warp communication, just a reordering of register values.
  //
  // Returns failure, without modifying the IR, if some destination element is
  // not held by any source register under this mapping.
  LogicalResult processLayout(amdgpu::ExtractSliceOp op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
    Location loc = op->getLoc();
    auto srcTy = cast<RankedTensorType>(op.getSource().getType());
    auto dstTy = cast<RankedTensorType>(op.getType());
    auto offsets = op.getStaticOffsets();
    auto linearLayoutSrc = triton::gpu::toLinearLayout(srcTy);
    // Reorder the destination's output dims to match the source's so that the
    // coordinate vectors produced by both layouts are directly comparable.
    auto outDimNames = llvm::to_vector(linearLayoutSrc.getOutDimNames());
    auto linearLayoutDst =
        triton::gpu::toLinearLayout(dstTy).transposeOuts(outDimNames);
    auto ctx = rewriter.getContext();
    int rank = srcTy.getRank();
    StringAttr kReg = StringAttr::get(ctx, "register");

    // Element coordinates owned by register `regId` of `layout`, with every
    // non-register input dimension fixed at 0.
    auto coordsOfRegister = [&](const auto &layout, int regId) {
      SmallVector<std::pair<StringAttr, int32_t>> hardwareLocation;
      for (auto dimName : layout.getInDimNames())
        hardwareLocation.push_back({dimName, dimName == kReg ? regId : 0});
      return layout.apply(hardwareLocation);
    };

    // Map each source element location to the register that holds it.
    using ElemLocationKey = decltype(linearLayoutSrc.apply({}));
    llvm::MapVector<ElemLocationKey, unsigned> srcElemToReg;
    int srcRegNum = 1 << linearLayoutSrc.getBases().lookup(kReg).size();
    for (int regId = 0; regId < srcRegNum; ++regId)
      srcElemToReg[coordsOfRegister(linearLayoutSrc, regId)] = regId;

    // For each destination register, resolve the source register holding the
    // same element. This is done before any IR is created, so an inconsistent
    // layout pair fails cleanly instead of silently selecting register 0 (what
    // MapVector::lookup would return) in builds where asserts are disabled.
    int dstRegNum = 1 << linearLayoutDst.getBases().lookup(kReg).size();
    SmallVector<unsigned> srcRegForDst;
    srcRegForDst.reserve(dstRegNum);
    for (int regId = 0; regId < dstRegNum; ++regId) {
      auto elemCoords = coordsOfRegister(linearLayoutDst, regId);
      // Translate slice-local coordinates into source-tensor coordinates.
      for (int i = 0; i < rank; ++i)
        elemCoords[i].second += offsets[i];
      auto it = srcElemToReg.find(elemCoords);
      if (it == srcElemToReg.end())
        return rewriter.notifyMatchFailure(
            op, "destination element is not held by any source register");
      srcRegForDst.push_back(it->second);
    }

    // Only now touch the IR: unpack the source values and forward them.
    auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter);
    SmallVector<Value> resultVals;
    resultVals.reserve(srcRegForDst.size());
    for (unsigned srcRegId : srcRegForDst)
      resultVals.push_back(vals[srcRegId]);
    Value ret = packLLElements(loc, this->getTypeConverter(), resultVals,
                               rewriter, dstTy);
    rewriter.replaceOp(op, ret);
    return success();
  }

  LogicalResult
  matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    return processLayout(op, adaptor, rewriter);
  }
};
}
namespace mlir::triton::AMD {

/// Registers the ExtractSliceOp -> LLVM lowering pattern.
///
/// @param typeConverter converter handed to the pattern for result types.
/// @param patterns      pattern set that receives the new pattern.
/// @param benefit       benefit assigned to the registered pattern.
void populateExtractSliceOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                          RewritePatternSet &patterns,
                                          PatternBenefit benefit) {
  patterns.add<ExtractSliceOpConversion>(
      /*typeConverter=*/typeConverter,
      /*benefit=*/benefit);
}

} // namespace mlir::triton::AMD