#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/MemoryPromotion.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/Pass/Pass.h"
using namespace mlir;
namespace {
/// Test pass that promotes selected arguments of a `gpu.func` to workgroup
/// memory.
///
/// An argument is selected when it carries the unit attribute
/// `gpu.test_promote_workgroup`. Each selected argument index is passed to
/// `promoteToWorkgroupMemory`.
struct TestGpuMemoryPromotionPass
    : public PassWrapper<TestGpuMemoryPromotionPass,
                         OperationPass<gpu::GPUFuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestGpuMemoryPromotionPass)

  /// Declares the dialects that must be loaded before this pass runs.
  // NOTE(review): presumably these are the dialects whose ops
  // promoteToWorkgroupMemory may create; keep this list in sync with the
  // promotion implementation.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<affine::AffineDialect, memref::MemRefDialect,
                    scf::SCFDialect>();
  }

  /// Command-line name under which the pass is registered.
  StringRef getArgument() const final { return "test-gpu-memory-promotion"; }

  /// One-line description shown in the pass listing.
  StringRef getDescription() const final {
    return "Promotes the annotated arguments of gpu.func to workgroup memory.";
  }

  /// Promotes every function argument tagged with
  /// `gpu.test_promote_workgroup`.
  void runOnOperation() override {
    gpu::GPUFuncOp op = getOperation();
    // The argument count is read once, before any promotion is performed.
    for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
      if (op.getArgAttrOfType<UnitAttr>(i, "gpu.test_promote_workgroup"))
        promoteToWorkgroupMemory(op, i);
    }
  }
};
} // namespace
namespace mlir {
void registerTestGpuMemoryPromotionPass() {
PassRegistration<TestGpuMemoryPromotionPass>();
}
}