#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
#include "mlir/Conversion/GPUToSPIRV/GPUToSPIRV.h"
#include "mlir/Conversion/MemRefToSPIRV/MemRefToSPIRV.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
using namespace mlir;
namespace {
/// Module-level pass that converts the `gpu.module` operations found in the
/// processed module using the SPIR-V conversion patterns registered in
/// `runOnOperation`.
///
/// The CRTP base `ConvertGPUToSPIRVBase` is supplied by the included pass
/// headers and is not visible in this file.
class GPUToSPIRVPass : public ConvertGPUToSPIRVBase<GPUToSPIRVPass> {
public:
  /// Entry point invoked by the pass manager; defined out of line below.
  void runOnOperation() override;
};
} // namespace
void GPUToSPIRVPass::runOnOperation() {
MLIRContext *context = &getContext();
ModuleOp module = getOperation();
SmallVector<Operation *, 1> kernelModules;
OpBuilder builder(context);
module.walk([&builder, &kernelModules](gpu::GPUModuleOp moduleOp) {
builder.setInsertionPoint(moduleOp.getOperation());
kernelModules.push_back(builder.clone(*moduleOp.getOperation()));
});
auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
std::unique_ptr<ConversionTarget> target =
SPIRVConversionTarget::get(targetAttr);
SPIRVTypeConverter typeConverter(targetAttr);
RewritePatternSet patterns(context);
populateGPUToSPIRVPatterns(typeConverter, patterns);
mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
populateMemRefToSPIRVPatterns(typeConverter, patterns);
populateFuncToSPIRVPatterns(typeConverter, patterns);
if (failed(applyFullConversion(kernelModules, *target, std::move(patterns))))
return signalPassFailure();
}
/// Factory for the GPU-to-SPIR-V conversion pass.
///
/// @return A newly allocated `GPUToSPIRVPass`, owned by the caller through the
///         generic `OperationPass<ModuleOp>` interface.
std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertGPUToSPIRVPass() {
  std::unique_ptr<OperationPass<ModuleOp>> pass =
      std::make_unique<GPUToSPIRVPass>();
  return pass;
}