#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "llvm/ADT/ScopeExit.h"
using namespace mlir;
using namespace triton;
using namespace triton::gpu;
namespace ttng = triton::nvidia_gpu;
// Callback used to run an OpPassManager on a temporary container module and
// report success/failure. `function_ref` is non-owning, so the referenced
// callable must outlive every call made through it.
using RunPipelineFn = function_ref<LogicalResult(OpPassManager &, ModuleOp)>;
/// Moves the body of `partition` into a freshly created standalone module that
/// holds a single `FuncOp` named "container", so that regular module-level pass
/// pipelines can be run on it in isolation.
///
/// The new module copies every attribute of the analyzed module and then
/// overrides the num-warps attribute with `numWarps`. Every `WarpReturnOp` in
/// the moved body is rewritten into a `ReturnOp`, and the resulting module is
/// verified (verification failure is a fatal error). Finally, for each explicit
/// capture of the enclosing `WarpSpecializeOp`, the dimension-0 contiguity,
/// divisibility and constancy recorded in `axisInfo` are attached to the
/// function argument with the same index as `tt.*` argument attributes.
///
/// After this call `partition` has no body; extractPartitionBody performs the
/// inverse move.
static OwningOpRef<ModuleOp> takeIntoFunction(ModuleAxisInfoAnalysis &axisInfo,
                                              Region *partition, int numWarps) {
  ModuleOp parentMod = axisInfo.getModuleOp();
  Location loc = parentMod.getLoc();

  OwningOpRef<ModuleOp> container = ModuleOp::create(loc);
  auto builder = OpBuilder::atBlockBegin(container->getBody());

  // The function signature mirrors the partition's block argument types; it
  // must be computed before the body is moved out of `partition`.
  FunctionType signature =
      builder.getFunctionType(partition->getArgumentTypes(), {});
  auto funcOp = builder.create<FuncOp>(loc, "container", signature);
  funcOp.getBody().takeBody(*partition);

  // Inherit all module attributes, then override the warp count.
  ModuleOp containerMod = container.get();
  containerMod->setAttrs(parentMod->getAttrs());
  containerMod->setAttr(AttrNumWarpsName, builder.getI32IntegerAttr(numWarps));

  // Swap each WarpReturnOp terminator for a ReturnOp so the standalone
  // function is well-formed.
  funcOp.walk([](WarpReturnOp warpRet) {
    OpBuilder rewriter(warpRet);
    rewriter.create<ReturnOp>(warpRet.getLoc());
    warpRet.erase();
  });

  if (failed(mlir::verify(*container)))
    llvm::report_fatal_error("expected partition region to make valid IR");

  // Carry the axis facts known for each captured value over to the matching
  // function argument, so passes run on the container can still use them.
  auto wsOp = partition->getParentOfType<WarpSpecializeOp>();
  auto *funcAxisInfo =
      axisInfo.getFuncData(wsOp->getParentOfType<FunctionOpInterface>());
  assert(funcAxisInfo && "expected to find function axis info");

  auto captures = wsOp.getExplicitCaptures();
  for (unsigned argIdx = 0, numArgs = captures.size(); argIdx != numArgs;
       ++argIdx) {
    AxisInfo info = funcAxisInfo->lookup(captures[argIdx]);
    auto setArgI64 = [&](StringRef attrName, int64_t value) {
      funcOp.setArgAttr(argIdx, attrName, builder.getI64IntegerAttr(value));
    };
    setArgI64("tt.contiguity", info.getContiguity(0));
    setArgI64("tt.divisibility", info.getDivisibility(0));
    setArgI64("tt.constancy", info.getConstancy(0));
  }
  return container;
}
/// Inverse of takeIntoFunction: rewrites every `ReturnOp` in the "container"
/// function back into a `WarpReturnOp` and moves the function body into
/// `partition`. The (now empty) container module is destroyed when the
/// by-value `container` argument goes out of scope.
static void extractPartitionBody(OwningOpRef<ModuleOp> container,
                                 Region *partition) {
  auto funcOp = cast<FuncOp>(container->lookupSymbol("container"));

  // Restore the partition terminators.
  funcOp.walk([](ReturnOp ret) {
    OpBuilder rewriter(ret);
    rewriter.create<WarpReturnOp>(ret.getLoc());
    ret.erase();
  });

  Region &funcBody = funcOp.getBody();
  partition->takeBody(funcBody);
}
/// Re-lays-out the tensors of `partition` for a new warp count.
///
/// The partition body is moved into a standalone container module (see
/// takeIntoFunction, which is given `prevNumWarps`). Every ranked tensor type in
/// it is stripped of its encoding; tensor descriptor types are left unchanged.
/// The body is then converted back to TritonGPU with `newNumWarps` warps and
/// run through a layout clean-up pipeline. Finally it is moved back into
/// `partition`.
///
/// \param runPipeline runs an OpPassManager on the temporary container module.
/// \returns failure if the module has no target attribute or if either
///          pipeline fails. On a pipeline failure the partition body is NOT
///          restored (the container is destroyed on return).
static LogicalResult relayoutWarps(ModuleAxisInfoAnalysis &axisInfo,
                                   Region *partition, int prevNumWarps,
                                   int newNumWarps, RunPipelineFn runPipeline) {
  ModuleOp mod = axisInfo.getModuleOp();

  // Validate the module configuration before mutating any IR. Once
  // takeIntoFunction has run, the partition body lives inside `container`, and
  // an early return would destroy it together with the container.
  auto target = mod->getAttrOfType<StringAttr>(AttrTargetName);
  if (!target)
    return mlir::emitError(mod.getLoc(), "module missing target specification");
  int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(mod);
  int numCTAs = TritonGPUDialect::getNumCTAs(mod);

  OwningOpRef<ModuleOp> container =
      takeIntoFunction(axisInfo, partition, prevNumWarps);

  // Drop the layout encoding from every ranked tensor type. Tensor descriptor
  // types are returned as-is and their sub-elements are not visited.
  mlir::AttrTypeReplacer replacer;
  replacer.addReplacement(
      [](RankedTensorType ty) { return ty.cloneWithEncoding({}); });
  replacer.addReplacement([](TensorDescType ty) -> std::pair<Type, WalkResult> {
    return {ty, WalkResult::skip()};
  });
  replacer.recursivelyReplaceElementsIn(*container, /*replaceAttrs=*/false,
                                        /*replaceLocs=*/false,
                                        /*replaceTypes=*/true);

  // Re-run the Triton -> TritonGPU conversion with the new warp count.
  OpPassManager pm;
  pm.addPass(
      createConvertTritonToTritonGPU({target.str(), newNumWarps, threadsPerWarp,
                                      numCTAs, true}));
  pm.addPass(createRelayoutTritonGPU());
  if (failed(runPipeline(pm, *container)))
    return failure();

  // Forward each cast's first operand to the users of its first result and
  // erase the cast. This assumes single-operand, single-result casts.
  container->walk([](UnrealizedConversionCastOp op) {
    op.getResult(0).replaceAllUsesWith(op.getOperand(0));
    op.erase();
  });

  // Layout clean-up / optimization on the re-converted body.
  pm.clear();
  pm.addPass(createTritonGPUCoalesce());
  pm.addPass(createTritonGPURemoveLayoutConversions());
  pm.addPass(createTritonGPUOptimizeThreadLocality());
  pm.addPass(createTritonGPUAccelerateMatmul());
  pm.addPass(createTritonGPURemoveLayoutConversions());
  if (failed(runPipeline(pm, *container)))
    return failure();

  extractPartitionBody(std::move(container), partition);
  return success();
}
/// Estimates the number of 32-bit registers needed to hold `ty` across all
/// threads of the CTA. The count is per-thread elements (from the layout
/// encoding) x threads per warp x warps per CTA x element bits / 32. Pointer
/// elements are counted as 64 bits. All arithmetic is `unsigned`.
static unsigned getTensorNumI32Regs(RankedTensorType ty) {
  const unsigned bitsPerElem = isa<PointerType>(ty.getElementType())
                                   ? 64u
                                   : ty.getElementTypeBitWidth();
  const unsigned elemsPerThread = getTotalElemsPerThread(ty);
  const unsigned threadsPerCTA =
      product(getThreadsPerWarp(ty)) * product(getWarpsPerCTA(ty));
  const unsigned totalBits = elemsPerThread * threadsPerCTA * bitsPerElem;
  return totalBits / 32;
}
/// Chooses a (possibly smaller) warp count for each partition of `wsOp`,
/// re-lays-out partitions whose warp count changed, and records the new warp
/// counts and per-partition requested-register estimates on the op.
///
/// Heuristic:
///  * A partition's register demand is twice the largest single tensor
///    (operand or result of any op in its body), measured by
///    getTensorNumI32Regs.
///  * Certain ops impose a minimum warp count on their partition: 2 for TMA
///    gather/scatter/global-to-local copy, 4 for TMEM load/store/alloc. When a
///    partition contains several such ops, the largest minimum applies.
///  * Repeatedly halve the first partition (in order) whose per-thread demand
///    at half its warps still fits in 65536 registers divided evenly over all
///    warps of the kernel after the halving. The total counts the
///    lookupNumWarps(wsOp) warps plus every partition. Stop when no partition
///    can be halved.
///
/// \returns failure if re-laying-out any partition fails.
static LogicalResult optimizePartitionNumWarps(ModuleAxisInfoAnalysis &axisInfo,
                                               WarpSpecializeOp wsOp,
                                               RunPipelineFn runPipeline) {
  // Estimate each partition's register footprint from its largest tensor.
  SmallVector<unsigned> maxTensorRegs;
  for (Region *partition : wsOp.getPartitionRegions()) {
    unsigned &tensorRegs = maxTensorRegs.emplace_back(0);
    partition->walk([&](Operation *op) {
      for (Type type :
           llvm::concat<Type>(op->getOperandTypes(), op->getResultTypes())) {
        if (auto tensor = dyn_cast<RankedTensorType>(type))
          tensorRegs = std::max(tensorRegs, getTensorNumI32Regs(tensor));
      }
    });
    // Heuristic factor of 2 on top of the largest tensor (presumably headroom
    // for other live values -- TODO confirm rationale).
    tensorRegs *= 2;
  }

  // NOTE(review): 1 << 16 presumably models a 64K-entry register file --
  // confirm for each target.
  constexpr unsigned nTotalRegs = 1 << 16;
  const unsigned threadsPerWarp =
      TritonGPUDialect::getThreadsPerWarp(axisInfo.getModuleOp());
  const unsigned defaultNumWarps = lookupNumWarps(wsOp);
  SmallVector<int32_t> partitionNumWarps =
      llvm::to_vector(wsOp.getPartitionNumWarps());

  // Minimum warp count per partition imposed by the ops it contains. Take the
  // maximum over all ops so that a later op with a weaker requirement cannot
  // lower a stricter minimum set by an earlier one.
  SmallVector<int32_t> minWarpsForPartition(partitionNumWarps.size(), 1);
  for (auto [minWarps, region] :
       llvm::zip(minWarpsForPartition, wsOp.getPartitionRegions())) {
    region->walk([minWarps = &minWarps](Operation *op) {
      if (isa<ttng::AsyncTMAGatherOp, ttng::AsyncTMAScatterOp,
              ttng::AsyncTMACopyGlobalToLocalOp>(op))
        *minWarps = std::max<int32_t>(*minWarps, 2);
      else if (isa<ttng::TMEMLoadOp, ttng::TMEMStoreOp, ttng::TMEMAllocOp>(op))
        *minWarps = std::max<int32_t>(*minWarps, 4);
    });
  }

  // Greedily halve one partition at a time, recomputing the kernel-wide warp
  // total after each change.
  bool changed;
  do {
    changed = false;
    int32_t curTotalNumWarps = std::accumulate(
        partitionNumWarps.begin(), partitionNumWarps.end(), defaultNumWarps);
    for (auto [minWarps, numWarps, tensorRegs] :
         llvm::zip(minWarpsForPartition, partitionNumWarps, maxTensorRegs)) {
      // Halving must not take the partition below its minimum. Since every
      // minimum is >= 1, this also guarantees numWarps / 2 >= 1 for the
      // division below.
      if (numWarps / 2 < minWarps)
        continue;
      unsigned reqRegsPerThread = tensorRegs / threadsPerWarp / (numWarps / 2);
      unsigned nextTotalNumWarps = curTotalNumWarps - (numWarps / 2);
      unsigned nextRegsPerThread =
          nTotalRegs / threadsPerWarp / nextTotalNumWarps;
      if (reqRegsPerThread <= nextRegsPerThread) {
        numWarps /= 2;
        changed = true;
        break;
      }
    }
  } while (changed);

  // Requested registers: fixed estimates of 88 for partitions holding tensors
  // and 24 otherwise. Only partitions with tensors whose warp count changed
  // need their layouts rebuilt.
  SmallVector<int32_t> estRegUsage(partitionNumWarps.size());
  for (auto [partition, newNumWarps, prevNumWarps, tensorRegs, estRegs] :
       llvm::zip(wsOp.getPartitionRegions(), partitionNumWarps,
                 wsOp.getPartitionNumWarps(), maxTensorRegs, estRegUsage)) {
    estRegs = tensorRegs ? 88 : 24;
    if (newNumWarps == prevNumWarps || !tensorRegs)
      continue;
    if (failed(relayoutWarps(axisInfo, partition, prevNumWarps, newNumWarps,
                             runPipeline)))
      return failure();
  }

  wsOp.setRequestedRegisters(estRegUsage);
  wsOp.setPartitionNumWarps(partitionNumWarps);
  return success();
}
// Instantiate the tablegen-generated pass definition; this provides the
// `impl::TritonGPUOptimizePartitionWarpsBase` class used by the pass below.
namespace mlir::triton::gpu {
#define GEN_PASS_DEF_TRITONGPUOPTIMIZEPARTITIONWARPS
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"
} // namespace mlir::triton::gpu
namespace {
// Pass that shrinks the warp counts of warp-specialized partitions where the
// register estimate allows, and re-lays-out the affected partitions.
struct OptimizePartitionWarps
: triton::gpu::impl::TritonGPUOptimizePartitionWarpsBase<
OptimizePartitionWarps> {
// Inherit the generated base-class constructors (including any option-taking
// ones the generated base provides).
using TritonGPUOptimizePartitionWarpsBase::
TritonGPUOptimizePartitionWarpsBase;
void runOnOperation() override;
};
} // namespace
void OptimizePartitionWarps::runOnOperation() {
ModuleAxisInfoAnalysis axisInfo(getOperation());
auto runPipelineFn = [&](OpPassManager &pm, ModuleOp container) {
getOperation().push_back(container);
auto remove = llvm::make_scope_exit([&] { container->remove(); });
return runPipeline(pm, container);
};
WalkResult result = getOperation().walk([&](WarpSpecializeOp wsOp) {
if (failed(optimizePartitionNumWarps(axisInfo, wsOp, runPipelineFn)))
return WalkResult::interrupt();
return WalkResult::skip();
});
if (result.wasInterrupted())
return signalPassFailure();
}