#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Target/LLVM/NVVM/Target.h"
#include "llvm/Support/Regex.h"
namespace mlir {
#define GEN_PASS_DEF_GPUNVVMATTACHTARGET
#include "mlir/Dialect/GPU/Transforms/Passes.h.inc"
}
using namespace mlir;
using namespace mlir::NVVM;
namespace {
/// Pass that builds an NVVM target attribute (`NVVMTargetAttr`) from its
/// options and appends it to the target list of `gpu.module` operations
/// nested in the root operation.
struct NVVMAttachTarget
    : public impl::GpuNVVMAttachTargetBase<NVVMAttachTarget> {
  using Base::Base;

  /// Returns a dictionary of unit attributes for the enabled flag options
  /// (`fast`, `ftz`), or a null attribute when none is enabled.
  DictionaryAttr getFlags(OpBuilder &builder) const;

  void runOnOperation() override;

  /// The pass creates NVVM dialect attributes, so the NVVM dialect is
  /// declared as a dependency.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<NVVM::NVVMDialect>();
  }
};
} // namespace
/// Collects the boolean flag options into a dictionary attribute.
///
/// Each enabled option (`fastFlag` -> "fast", `ftzFlag` -> "ftz") contributes
/// an entry mapping its name to a unit attribute.
///
/// @param builder Builder used to create the attributes.
/// @return The flag dictionary, or a null DictionaryAttr when no flag is set.
DictionaryAttr NVVMAttachTarget::getFlags(OpBuilder &builder) const {
  SmallVector<NamedAttribute, 2> entries;
  if (fastFlag)
    entries.push_back(builder.getNamedAttr("fast", builder.getUnitAttr()));
  if (ftzFlag)
    entries.push_back(builder.getNamedAttr("ftz", builder.getUnitAttr()));
  // No enabled flags: hand back a null attribute rather than an empty dict.
  return entries.empty() ? DictionaryAttr()
                         : builder.getDictionaryAttr(entries);
}
void NVVMAttachTarget::runOnOperation() {
OpBuilder builder(&getContext());
ArrayRef<std::string> libs(linkLibs);
SmallVector<StringRef> filesToLink(libs.begin(), libs.end());
auto target = builder.getAttr<NVVMTargetAttr>(
optLevel, triple, chip, features, getFlags(builder),
filesToLink.empty() ? nullptr : builder.getStrArrayAttr(filesToLink));
llvm::Regex matcher(moduleMatcher);
for (Region ®ion : getOperation()->getRegions())
for (Block &block : region.getBlocks())
for (auto module : block.getOps<gpu::GPUModuleOp>()) {
if (!moduleMatcher.empty() && !matcher.match(module.getName()))
continue;
SmallVector<Attribute> targets;
if (std::optional<ArrayAttr> attrs = module.getTargets())
targets.append(attrs->getValue().begin(), attrs->getValue().end());
targets.push_back(target);
targets.erase(llvm::unique(targets), targets.end());
module.setTargetsAttr(builder.getArrayAttr(targets));
}
}