#pragma once
#include "mlir/IR/Diagnostics.h"
#include "mlir/Pass/Pass.h"
#include "triton/Analysis/AxisInfo.h"
using namespace mlir;
using namespace mlir::triton;
namespace mlir::test {
/// Test pass over a module that emits, for every result of every operation
/// nested in a `FuncOp`, a remark describing what the axis-info analysis
/// knows about that result.
struct TestAxisInfoPass
    : public PassWrapper<TestAxisInfoPass, OperationPass<ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);

  /// Pass argument string identifying this pass.
  StringRef getArgument() const override { return "test-print-alignment"; }

  /// Human-readable summary of the pass.
  StringRef getDescription() const final {
    return "print the result of the alignment analysis pass";
  }

  /// Builds the analysis through `getAnalysis` and walks every `FuncOp` in the
  /// module. Each op result produces one remark at the op's location of the
  /// form `<value> => <axis info>`; the axis-info part is omitted when the
  /// analysis returns no entry for that value.
  void runOnOperation() override {
    ModuleOp module = getOperation();
    ModuleAxisInfoAnalysis analysis = getAnalysis(module);

    module.walk([&](FuncOp func) {
      func.walk([&](Operation *nested) {
        // Operations without results have an empty result range, so they
        // contribute no remarks.
        for (Value value : nested->getResults()) {
          InFlightDiagnostic remark = mlir::emitRemark(nested->getLoc());
          remark << value << " => ";

          if (auto *info = analysis.getAxisInfo(value)) {
            // Render the axis info into a string so it can be appended to
            // the diagnostic.
            std::string rendered;
            llvm::raw_string_ostream stream(rendered);
            info->print(stream);
            remark << rendered;
          }
        }
      });
    });
  }

protected:
  /// Virtual hook that constructs the module-level axis-info analysis used by
  /// `runOnOperation`. The default builds a `ModuleAxisInfoAnalysis` directly
  /// from `moduleOp`.
  virtual ModuleAxisInfoAnalysis getAnalysis(ModuleOp moduleOp) const {
    return ModuleAxisInfoAnalysis(moduleOp);
  }
};
}