#include "mlir/Dialect/Transform/IR/Utils.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/FunctionInterfaces.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
#define DEBUG_TYPE "transform-dialect-utils"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
/// Returns whether `func1` may be merged into `func2` (as done by `mergeInto`,
/// which keeps `func2` and erases `func1`).
///
/// `func1` has to be external (a declaration without body). `func2` has to be
/// either external itself or public.
static bool canMergeInto(FunctionOpInterface func1, FunctionOpInterface func2) {
  // Only a bodiless function can be folded into another one.
  if (!func1.isExternal())
    return false;
  return func2.isExternal() || func2.isPublic();
}
/// Merges the external function `func1` into `func2` and erases `func1`.
///
/// Preconditions (asserted): `canMergeInto(func1, func2)` holds and both ops
/// have the same parent op.
///
/// The two function types must be identical. Each argument is then handled
/// as follows:
///   - if `func2` carries neither the transform dialect's "consumed" nor its
///     "read-only" argument attribute, it inherits the one from `func1`
///     ("consumed" wins if `func1` has it);
///   - otherwise, every one of those attributes present on `func2` must also
///     be present on `func1`.
///
/// Returns an empty (successful) diagnostic once `func1` has been erased, or
/// an error emitted at `func1` on a signature or annotation mismatch.
/// Annotations already copied onto `func2` for earlier arguments are not
/// rolled back when a later argument mismatches.
static InFlightDiagnostic mergeInto(FunctionOpInterface func1,
                                    FunctionOpInterface func2) {
  assert(canMergeInto(func1, func2));
  assert(func1->getParentOp() == func2->getParentOp() &&
         "expected func1 and func2 to be in the same parent op");

  // The declaration and the surviving function must agree on the signature.
  Type survivorType = func2.getFunctionType();
  if (func1.getFunctionType() != survivorType) {
    return func1.emitError()
           << "external definition has a mismatching signature ("
           << survivorType << ")";
  }

  MLIRContext *ctx = func1->getContext();
  auto *transformDialect =
      ctx->getLoadedDialect<transform::TransformDialect>();
  StringAttr consumedAttrName = transformDialect->getConsumedAttrName();
  StringAttr readOnlyAttrName = transformDialect->getReadOnlyAttrName();

  // Whether argument `argNo` of `fn` carries the attribute `attrName`.
  auto hasArgAttr = [](FunctionOpInterface fn, unsigned argNo,
                       StringAttr attrName) -> bool {
    return static_cast<bool>(fn.getArgAttr(argNo, attrName));
  };

  for (unsigned argNo = 0, numArgs = func1.getNumArguments(); argNo != numArgs;
       ++argNo) {
    // "decl" refers to `func1` (the external op), "survivor" to `func2`.
    bool declConsumed = hasArgAttr(func1, argNo, consumedAttrName);
    bool declReadOnly = hasArgAttr(func1, argNo, readOnlyAttrName);
    bool survivorConsumed = hasArgAttr(func2, argNo, consumedAttrName);
    bool survivorReadOnly = hasArgAttr(func2, argNo, readOnlyAttrName);

    if (survivorConsumed || survivorReadOnly) {
      // Any annotation the survivor already has must also be on `func1`.
      bool mismatch = (survivorConsumed && !declConsumed) ||
                      (survivorReadOnly && !declReadOnly);
      if (mismatch) {
        return func1.emitError()
               << "external definition has mismatching consumption "
                  "annotations for argument #"
               << argNo;
      }
      continue;
    }

    // The survivor is unannotated: take over func1's annotation, if any.
    if (declConsumed)
      func2.setArgAttr(argNo, consumedAttrName, UnitAttr::get(ctx));
    else if (declReadOnly)
      func2.setArgAttr(argNo, readOnlyAttrName, UnitAttr::get(ctx));
  }

  // `func1` is the external op (see canMergeInto), so it can simply go away.
  assert(func1.isExternal());
  func1->erase();
  return InFlightDiagnostic();
}
/// Renames `symbol`, which is registered in `symbolTable`, to a name that is
/// unique with respect to both `symbolTable` and `otherSymbolTable`.
/// `collidingSymbol` is the op whose name clashed with `symbol`; it is used
/// only to attach a note if the renaming fails. Returns the error diagnostic
/// on failure and an empty (successful) diagnostic otherwise.
static InFlightDiagnostic renameSymbolToUnique(SymbolOpInterface symbol,
                                               SymbolOpInterface collidingSymbol,
                                               SymbolTable &symbolTable,
                                               SymbolTable &otherSymbolTable) {
  LLVM_DEBUG(llvm::dbgs() << ", renaming\n");
  FailureOr<StringAttr> maybeNewName =
      symbolTable.renameToUnique(symbol, {&otherSymbolTable});
  if (failed(maybeNewName)) {
    InFlightDiagnostic diag = symbol->emitError("failed to rename symbol");
    diag.attachNote(collidingSymbol->getLoc())
        << "attempted renaming due to collision with this op";
    return diag;
  }
  LLVM_DEBUG(DBGS() << " renamed to @" << maybeNewName->getValue() << "\n");
  return InFlightDiagnostic();
}

/// Moves all symbols of `other` into `target`, resolving name collisions.
/// Both ops must have the `SymbolTable` trait (asserted).
///
/// Step 1 checks every symbol of each op against the other op:
///   - two functions where one can be merged into the other (see
///     `canMergeInto`) are left for step 2;
///   - otherwise, if the scanned op is private it is renamed, else if its
///     counterpart is private that one is renamed (to a name unique in both
///     symbol tables);
///   - otherwise a "doubly defined symbol" error is returned.
/// Both ops are then verified.
///
/// Step 2 moves every symbol op of `other` to the end of the first block of
/// `target`. For each remaining collision, the external function is merged
/// into its counterpart via `mergeInto`, and the target symbol table is kept
/// pointing at the surviving op. `target` is verified once more at the end.
///
/// Returns an empty (successful) diagnostic on success and the error
/// diagnostic otherwise. NOTE: on failure, `target` and `other` may already
/// have been modified (symbols renamed in step 1, ops moved in step 2).
InFlightDiagnostic
transform::detail::mergeSymbolsInto(Operation *target,
                                    OwningOpRef<Operation *> other) {
  assert(target->hasTrait<OpTrait::SymbolTable>() &&
         "requires target to implement the 'SymbolTable' trait");
  assert(other->hasTrait<OpTrait::SymbolTable>() &&
         "requires other to implement the 'SymbolTable' trait");

  SymbolTable targetSymbolTable(target);
  SymbolTable otherSymbolTable(*other);

  // Step 1: resolve collisions that can be resolved by renaming private
  // symbols. Each table is scanned against the other one.
  LLVM_DEBUG(DBGS() << "renaming private symbols to resolve conflicts:\n");
  std::pair<SymbolTable *, SymbolTable *> directions[] = {
      {&targetSymbolTable, &otherSymbolTable},
      {&otherSymbolTable, &targetSymbolTable}};
  for (auto [table, peerTable] : directions) {
    for (Operation &op : table->getOp()->getRegion(0).front()) {
      auto symbolOp = dyn_cast<SymbolOpInterface>(op);
      if (!symbolOp)
        continue;
      StringAttr name = symbolOp.getNameAttr();
      LLVM_DEBUG(DBGS() << " found @" << name.getValue() << "\n");

      // Is there an op with the same name on the other side?
      auto collidingOp =
          cast_or_null<SymbolOpInterface>(peerTable->lookup(name));
      if (!collidingOp)
        continue;
      LLVM_DEBUG(DBGS() << " collision found for @" << name.getValue());

      // Colliding functions are fine if one can be merged into the other;
      // the merge itself happens in step 2.
      if (auto funcOp = dyn_cast<FunctionOpInterface>(op),
          collidingFuncOp =
              dyn_cast<FunctionOpInterface>(collidingOp.getOperation());
          funcOp && collidingFuncOp) {
        if (canMergeInto(funcOp, collidingFuncOp) ||
            canMergeInto(collidingFuncOp, funcOp)) {
          LLVM_DEBUG(llvm::dbgs() << " but both ops are functions and "
                                     "will be merged\n");
          continue;
        }
        // Not mergeable: handle it like any other collision below.
        LLVM_DEBUG(llvm::dbgs() << " and both ops are function definitions");
      }

      // A private op can be renamed away; try the scanned op first, then its
      // counterpart.
      if (symbolOp.isPrivate()) {
        InFlightDiagnostic diag =
            renameSymbolToUnique(symbolOp, collidingOp, *table, *peerTable);
        if (failed(diag))
          return diag;
        continue;
      }
      if (collidingOp.isPrivate()) {
        InFlightDiagnostic diag =
            renameSymbolToUnique(collidingOp, symbolOp, *peerTable, *table);
        if (failed(diag))
          return diag;
        continue;
      }

      LLVM_DEBUG(llvm::dbgs() << ", emitting error\n");
      InFlightDiagnostic diag = symbolOp.emitError()
                                << "doubly defined symbol @" << name.getValue();
      diag.attachNote(collidingOp->getLoc()) << "previously defined here";
      return diag;
    }
  }

  for (auto *op : SmallVector<Operation *>{target, *other}) {
    if (failed(mlir::verify(op)))
      return op->emitError() << "failed to verify input op after renaming";
  }

  // Step 2: move all symbols of `other` into `target` and merge the function
  // pairs that step 1 left colliding.
  LLVM_DEBUG(DBGS() << "moving all symbols into target\n");
  Block &targetBody = target->getRegion(0).front();

  // Collect first: moving ops while iterating over their block would
  // invalidate the iteration.
  SmallVector<SymbolOpInterface> opsToMove;
  for (Operation &op : other->getRegion(0).front()) {
    if (auto symbol = dyn_cast<SymbolOpInterface>(op))
      opsToMove.push_back(symbol);
  }

  for (SymbolOpInterface op : opsToMove) {
    // Remember the op of the same name in `target`, if any.
    auto collidingOp = cast_or_null<SymbolOpInterface>(
        targetSymbolTable.lookup(op.getNameAttr()));

    // The op is moved even if it collides.
    LLVM_DEBUG(DBGS() << " moving @" << op.getName());
    op->moveBefore(&targetBody, targetBody.end());

    if (!collidingOp) {
      LLVM_DEBUG(llvm::dbgs() << " without collision\n");
      continue;
    }

    // Step 1 renamed away or reported every other kind of collision, so both
    // ops are functions and one can be merged into the other. Both now live
    // in `target`, so orient the pair such that `funcOp` is merged into
    // `collidingFuncOp`.
    auto funcOp = cast<FunctionOpInterface>(op.getOperation());
    auto collidingFuncOp =
        cast<FunctionOpInterface>(collidingOp.getOperation());
    if (!canMergeInto(funcOp, collidingFuncOp))
      std::swap(funcOp, collidingFuncOp);
    assert(canMergeInto(funcOp, collidingFuncOp));

    LLVM_DEBUG(llvm::dbgs() << " with collision, trying to keep op at "
                            << collidingFuncOp.getLoc() << ":\n"
                            << collidingFuncOp << "\n");

    // Make the symbol table map the name to the surviving op; this holds
    // whether or not the pair was swapped above.
    targetSymbolTable.remove(funcOp);
    targetSymbolTable.insert(collidingFuncOp);
    assert(targetSymbolTable.lookup(funcOp.getName()) == collidingFuncOp);

    InFlightDiagnostic diag = mergeInto(funcOp, collidingFuncOp);
    if (failed(diag))
      return diag;
  }

  if (failed(mlir::verify(target)))
    return target->emitError()
           << "failed to verify target op after merging symbols";

  LLVM_DEBUG(DBGS() << "done merging ops\n");
  return InFlightDiagnostic();
}