#include "TritonAMDGPUTransforms/Passes.h"
#include "mlir/Pass/PassManager.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace mlir {
#define GEN_PASS_DEF_TRITONAMDGPUHOISTLAYOUTCONVERSIONS
#include "TritonAMDGPUTransforms/Passes.h.inc"
namespace {
/// Tries to hoist a ttg.convert_layout whose result uses a dot-operand
/// encoding out of its innermost enclosing scf.for loop.
///
/// The op is moved only when:
///   * its result type is a RankedTensorType with a DotOperandEncodingAttr,
///   * its source value is produced by an operation (not a block argument),
///   * it is nested inside some scf.for, and
///   * that innermost scf.for does not contain the source's producer.
/// When all hold, the conversion is placed immediately after the producer of
/// its source, so it no longer sits inside that loop body. Otherwise the IR is
/// left unchanged.
static void hoistCvtDotOpOutOfLoop(ttg::ConvertLayoutOp cvtOp) {
  // Only conversions *into* a dot-operand layout are candidates.
  auto resultTy = dyn_cast<RankedTensorType>(cvtOp.getType());
  Attribute resultEnc = resultTy ? resultTy.getEncoding() : Attribute();
  // The null check must come first: isa<> on a null Attribute is invalid.
  if (!resultEnc || !isa<ttg::DotOperandEncodingAttr>(resultEnc))
    return;

  // Block-argument sources (e.g. loop-carried values) have no producer op.
  Operation *producer = cvtOp.getSrc().getDefiningOp();
  if (!producer)
    return;

  // Hoisting is only meaningful when the producer lives outside the loop.
  auto enclosingLoop = cvtOp->getParentOfType<scf::ForOp>();
  if (!enclosingLoop || enclosingLoop->isAncestor(producer))
    return;

  cvtOp->moveAfter(producer);
}
}
/// Function-level pass that visits every ttg.convert_layout op in the
/// function and hands each one to hoistCvtDotOpOutOfLoop.
struct TritonAMDGPUHoistLayoutConversionsPass
    : public impl::TritonAMDGPUHoistLayoutConversionsBase<
          TritonAMDGPUHoistLayoutConversionsPass> {
  void runOnOperation() override {
    tt::FuncOp func = getOperation();

    // Snapshot every candidate before touching the IR: hoisting moves ops,
    // and doing that while the walk is in progress could disturb traversal.
    SmallVector<ttg::ConvertLayoutOp> candidates;
    func.walk([&candidates](ttg::ConvertLayoutOp op) {
      candidates.push_back(op);
    });

    for (ttg::ConvertLayoutOp op : candidates)
      hoistCvtDotOpOutOfLoop(op);
  }
};
}