#include "mlir/Pass/Pass.h"
#include "triton/Analysis/Allocation.h"
using namespace mlir;
namespace {
unsigned getScratchSize128(Operation *) { return 128; }
enum class GetScratchSizeFunction {
None,
ValidConstant,
};
struct TestAllocationPass
: public PassWrapper<TestAllocationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
TestAllocationPass() = default;
TestAllocationPass(const TestAllocationPass &other)
: PassWrapper<TestAllocationPass, OperationPass<ModuleOp>>(other) {}
StringRef getArgument() const final { return "test-print-allocation"; }
StringRef getDescription() const final {
return "print the result of the allocation pass";
}
ModuleAllocation getModuleAllocation() {
switch (getScratchSizeFunction) {
case GetScratchSizeFunction::None:
return {getOperation()};
case GetScratchSizeFunction::ValidConstant:
return {getOperation(), getScratchSize128};
}
llvm_unreachable("Unhandled case");
}
void runOnOperation() override {
ModuleOp moduleOp = getOperation();
ModuleAllocation moduleAllocation = getModuleAllocation();
moduleOp.walk([&](triton::FuncOp funcOp) {
auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
mlir::emitRemark(funcOp.getLoc(), opName);
auto *allocation = moduleAllocation.getFuncData(funcOp);
funcOp.walk([&](Operation *op) {
auto scratchBufferId = allocation->getBufferId(op);
if (scratchBufferId != Allocation::InvalidBufferId) {
size_t offset = allocation->getOffset(scratchBufferId);
size_t size = allocation->getAllocatedSize(scratchBufferId);
mlir::emitRemark(op->getLoc())
<< (allocation->isVirtualBuffer(scratchBufferId) ? "virtual"
: "scratch")
<< " offset = " << offset << ", size = " << size;
}
if (op->getNumResults() < 1)
return;
for (Value result : op->getResults()) {
auto bufferId = allocation->getBufferId(result);
if (bufferId != Allocation::InvalidBufferId) {
size_t offset = allocation->getOffset(bufferId);
size_t size = allocation->getAllocatedSize(bufferId);
mlir::emitRemark(op->getLoc())
<< "offset = " << offset << ", size = " << size;
}
}
});
mlir::emitRemark(funcOp.getLoc())
<< "size = " << allocation->getSharedMemorySize();
});
}
Option<GetScratchSizeFunction> getScratchSizeFunction{
*this, "get-scratch-size-function",
llvm::cl::desc("Custom scratch size function to use"),
llvm::cl::init(GetScratchSizeFunction::None),
llvm::cl::values(
clEnumValN(GetScratchSizeFunction::None, "None", "None (default)"),
clEnumValN(GetScratchSizeFunction::ValidConstant, "ValidConstant",
"ValidConstant"))};
};
}
namespace mlir {
namespace test {
void registerTestAllocationPass() { PassRegistration<TestAllocationPass>(); }
}
}