#include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/SymbolTable.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
using namespace mlir;
/// Once the shared rename counter reaches this value, renameSymbol stops
/// generating numbered candidate names.
static constexpr unsigned maxFreeID = 1u << 20;
/// Produces a symbol name derived from `oldSymName` that is not yet defined in
/// `module`.
///
/// Candidates have the form `<oldSymName>_<N>`. N comes from pre-incrementing
/// the caller-owned counter `lastUsedID`, which keeps its final value so that
/// later calls continue from there.
///
/// Once the counter has reached `maxFreeID`, `<oldSymName>_` is returned
/// WITHOUT checking whether that name is already taken.
static StringAttr renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
                               spirv::ModuleOp module) {
  MLIRContext *context = module->getContext();

  SmallString<64> prefix(oldSymName);
  prefix += '_';

  while (lastUsedID < maxFreeID) {
    ++lastUsedID;
    StringAttr candidate = StringAttr::get(context, prefix + Twine(lastUsedID));
    if (SymbolTable::lookupSymbolIn(module, candidate) == nullptr)
      return candidate;
  }

  // Suffix space exhausted: fall back to the bare prefix.
  return StringAttr::get(context, prefix);
}
/// If `source` already defines a symbol with the same name as `op`, renames
/// `op` to a fresh name and rewrites every use of the old name inside `target`.
/// At the visible call sites, `target` is the module that contains `op`. Does
/// nothing when there is no clash.
///
/// `op` and the symbols of `source` end up in the same combined module, so the
/// fresh name must be unused in BOTH `target` and `source`.
///
/// Returns failure, with an error emitted on `op`, if no unused name can be
/// found or if not all uses of the old name can be updated.
static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
                                            spirv::ModuleOp target,
                                            spirv::ModuleOp source,
                                            unsigned &lastUsedID) {
  StringRef oldSymName = op.getName();
  if (!SymbolTable::lookupSymbolIn(source, oldSymName))
    return success();

  // renameSymbol only guarantees that its candidate is free in `target`. Keep
  // drawing candidates until `source` does not define the name either;
  // otherwise the two symbols would collide once the modules are merged.
  StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target);
  while (SymbolTable::lookupSymbolIn(source, newSymName) &&
         lastUsedID < maxFreeID)
    newSymName = renameSymbol(oldSymName, lastUsedID, target);

  // When the suffix counter is exhausted, renameSymbol falls back to a name it
  // has not checked. Never hand out a name that is already taken.
  if (SymbolTable::lookupSymbolIn(target, newSymName) ||
      SymbolTable::lookupSymbolIn(source, newSymName))
    return op.emitError("unable to find an unused symbol name for ")
           << oldSymName;

  if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
    return op.emitError("unable to update all symbol uses for ")
           << oldSymName << " to " << newSymName;

  SymbolTable::setSymbolName(op, newSymName);
  return success();
}
/// Hashes `symbolOp` by its operation name together with all of its attributes
/// except the symbol name. Two ops that differ only in their symbol names
/// therefore produce the same hash.
///
/// The operation name is mixed in to separate ops of different kinds whose
/// remaining attributes happen to coincide.
static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
  const StringRef symNameAttr = SymbolTable::getSymbolAttrName();
  auto nonNameAttrs = llvm::make_filter_range(
      symbolOp->getAttrs(), [symNameAttr](NamedAttribute namedAttr) {
        return namedAttr.getName() != symNameAttr;
      });
  const llvm::hash_code attrsHash =
      llvm::hash_combine_range(nonNameAttrs.begin(), nonNameAttrs.end());
  return llvm::hash_combine(symbolOp->getName(), attrsHash);
}
namespace mlir {
namespace spirv {
OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules,
OpBuilder &combinedModuleBuilder,
SymbolRenameListener symRenameListener) {
if (inputModules.empty())
return nullptr;
spirv::ModuleOp firstModule = inputModules.front();
auto addressingModel = firstModule.getAddressingModel();
auto memoryModel = firstModule.getMemoryModel();
auto vceTriple = firstModule.getVceTriple();
for (auto module : inputModules) {
if (module.getAddressingModel() != addressingModel ||
module.getMemoryModel() != memoryModel ||
module.getVceTriple() != vceTriple) {
module.emitError("input modules differ in addressing model, memory "
"model, and/or VCE triple");
return nullptr;
}
}
auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
firstModule.getLoc(), addressingModel, memoryModel, vceTriple);
combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
llvm::StringMap<spirv::ModuleOp> symNameToModuleMap;
unsigned lastUsedID = 0;
for (auto inputModule : inputModules) {
OwningOpRef<spirv::ModuleOp> moduleClone = inputModule.clone();
for (auto &op : *combinedModule.getBody()) {
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;
StringRef oldSymName = symbolOp.getName();
if (!isa<FuncOp>(op) &&
failed(updateSymbolAndAllUses(symbolOp, combinedModule, *moduleClone,
lastUsedID)))
return nullptr;
StringRef newSymName = symbolOp.getName();
if (symRenameListener && oldSymName != newSymName) {
spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
if (!originalModule) {
inputModule.emitError(
"unable to find original spirv::ModuleOp for symbol ")
<< oldSymName;
return nullptr;
}
symRenameListener(originalModule, oldSymName, newSymName);
symNameToModuleMap.erase(oldSymName);
symNameToModuleMap[newSymName] = originalModule;
}
}
for (auto &op : *moduleClone->getBody()) {
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;
StringRef oldSymName = symbolOp.getName();
if (failed(updateSymbolAndAllUses(symbolOp, *moduleClone, combinedModule,
lastUsedID)))
return nullptr;
StringRef newSymName = symbolOp.getName();
if (symRenameListener) {
if (oldSymName != newSymName)
symRenameListener(inputModule, oldSymName, newSymName);
auto emplaceResult =
symNameToModuleMap.try_emplace(newSymName, inputModule);
if (!emplaceResult.second) {
inputModule.emitError("did not expect to find an entry for symbol ")
<< symbolOp.getName();
return nullptr;
}
}
}
for (auto &op : *moduleClone->getBody())
combinedModuleBuilder.insert(op.clone());
}
DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
SmallVector<SymbolOpInterface, 0> eraseList;
for (auto &op : *combinedModule.getBody()) {
SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;
if (op.getNumOperands() != 0 || op.getNumResults() != 0)
continue;
if (isa<FuncOp>(op))
continue;
auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp);
if (result.second)
continue;
SymbolOpInterface replacementSymOp = result.first->second;
if (failed(SymbolTable::replaceAllSymbolUses(
symbolOp, replacementSymOp.getNameAttr(), combinedModule))) {
symbolOp.emitError("unable to update all symbol uses for ")
<< symbolOp.getName() << " to " << replacementSymOp.getName();
return nullptr;
}
eraseList.push_back(symbolOp);
}
for (auto symbolOp : eraseList)
symbolOp.erase();
return combinedModule;
}
}
}