#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
namespace mlir {
namespace triton {
namespace gpu {
#define GEN_PASS_DEF_TRITONGPUCOALESCEASYNCCOPY
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
/// Rewrite pattern for AsyncCopyGlobalToLocalOp that clips the source
/// tensor's per-thread contiguous element count along the contiguous
/// dimension (order[0] of its blocked encoding) down to the maximum
/// contiguous copy size derived from the register and shared layouts.
///
/// The pattern only fires when
///   * the source tensor carries a BlockedEncodingAttr,
///   * the destination memdesc carries a SwizzledSharedEncodingAttr, and
///   * the blocked layout's contiguous size is strictly greater than the
///     copy's contiguous size.
/// On a match, `src` (and `mask` / `other` when present) are routed through
/// ConvertLayoutOps to a new blocked encoding built from the clipped sizes,
/// and the copy op's operands are updated in place.
struct ClipAsyncCopySizePerThread
    : public OpRewritePattern<AsyncCopyGlobalToLocalOp> {
  using OpRewritePattern::OpRewritePattern;

  /// Returns success() after rewiring the operands of `copyOp` to the clipped
  /// layout, or a match failure (with a diagnostic) when the op does not meet
  /// the preconditions listed on the struct.
  LogicalResult matchAndRewrite(AsyncCopyGlobalToLocalOp copyOp,
                                PatternRewriter &rewriter) const override {
    Value src = copyOp.getSrc();
    Value mask = copyOp.getMask();
    Value other = copyOp.getOther();
    auto srcTy = cast<RankedTensorType>(src.getType());
    auto dstTy = cast<MemDescType>(copyOp.getResult().getType());

    auto blockedEnc = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
    if (!blockedEnc)
      return rewriter.notifyMatchFailure(copyOp,
                                         "src must be of blocked encoding");
    // Only the presence of a swizzled shared encoding matters here; none of
    // its parameters are read directly.
    if (!dyn_cast<SwizzledSharedEncodingAttr>(dstTy.getEncoding()))
      return rewriter.notifyMatchFailure(
          copyOp, "dst must be of swizzled shared encoding");

    // Max contiguous copy size: the number of consecutive elements in the
    // mapping obtained by composing the register layout with the inverse of
    // the shared layout.
    LinearLayout regLayout = triton::gpu::toLinearLayout(srcTy);
    LinearLayout sharedLayout = triton::gpu::toLinearLayout(dstTy);
    auto copyContigSize =
        regLayout.invertAndCompose(sharedLayout).getNumConsecutiveInOut();

    // Per-thread contiguous size of the blocked layout along the contiguous
    // dimension.
    auto contigPerThread = getContigPerThread(srcTy);
    const auto contigDim = blockedEnc.getOrder()[0];
    auto blockContigSize = contigPerThread[contigDim];

    // Nothing to clip unless the blocked layout is strictly wider.
    // NOTE(review): this strict check is presumably also what keeps the greedy
    // driver from re-matching an already-clipped op -- confirm.
    if (blockContigSize <= copyContigSize)
      return rewriter.notifyMatchFailure(
          copyOp,
          "blocked sizePerThread along contiguous dim must be greater than the "
          "max contiguous copy size ");

    contigPerThread[contigDim] = copyContigSize;

    // Build the replacement blocked encoding from the clipped sizes.
    auto mod = copyOp->getParentOfType<ModuleOp>();
    int numWarps = lookupNumWarps(copyOp);
    int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
    auto newBlockEnc = BlockedEncodingAttr::get(
        copyOp.getContext(), srcTy.getShape(), contigPerThread,
        blockedEnc.getOrder(), numWarps, threadsPerWarp,
        blockedEnc.getCTALayout());

    // Inserts a ConvertLayoutOp that re-encodes `operand` with `enc`, keeping
    // the rest of its tensor type unchanged.
    auto convertBlockLayout = [&](Value operand, BlockedEncodingAttr enc) {
      auto operandTy = cast<RankedTensorType>(operand.getType());
      auto convertedTy = operandTy.cloneWithEncoding(enc);
      auto cvt = rewriter.create<ConvertLayoutOp>(copyOp->getLoc(),
                                                  convertedTy, operand);
      return cvt.getResult();
    };

    src = convertBlockLayout(src, newBlockEnc);
    if (mask)
      mask = convertBlockLayout(mask, newBlockEnc);
    if (other)
      other = convertBlockLayout(other, newBlockEnc);

    // Swap the operands on the existing op; optional operands are only
    // touched when they were present to begin with.
    rewriter.modifyOpInPlace(copyOp, [&]() {
      copyOp.getSrcMutable().assign(src);
      if (mask)
        copyOp.getMaskMutable().assign(mask);
      if (other)
        copyOp.getOtherMutable().assign(other);
    });

    return success();
  }
};
/// Pass that runs the ClipAsyncCopySizePerThread pattern over the operation
/// it is scheduled on, using the greedy pattern rewrite driver.
struct CoalesceAsyncCopyPass
    : impl::TritonGPUCoalesceAsyncCopyBase<CoalesceAsyncCopyPass> {
  using Base::Base;

  /// Registers the single clipping pattern and applies it greedily; the pass
  /// is marked as failed if the driver reports failure.
  void runOnOperation() override {
    MLIRContext *ctx = &getContext();

    mlir::RewritePatternSet clipPatterns(ctx);
    clipPatterns.add<ClipAsyncCopySizePerThread>(ctx);

    ModuleOp moduleOp = getOperation();
    LogicalResult status =
        applyPatternsGreedily(moduleOp, std::move(clipPatterns));
    if (failed(status))
      signalPassFailure();
  }
};
}
}
}