#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringSwitch.h"
using namespace mlir;
/// Returns true if `op` could be a symbol table this file cannot reason about:
/// it has exactly one region but no dialect (`getDialect()` is null). Walks
/// treat such operations conservatively.
static bool isPotentiallyUnknownSymbolTable(Operation *op) {
  if (op->getNumRegions() != 1)
    return false;
  return op->getDialect() == nullptr;
}
/// Returns the string attribute stored under SymbolTable::getSymbolAttrName()
/// on `op`, or a null StringAttr if `op` has no such string attribute.
static StringAttr getNameIfSymbol(Operation *op) {
  StringRef attrName = SymbolTable::getSymbolAttrName();
  return op->getAttrOfType<StringAttr>(attrName);
}
/// Same as above, but the attribute name is supplied as an already-built
/// StringAttr key (callers in this file build it once before scanning many
/// operations).
static StringAttr getNameIfSymbol(Operation *op, StringAttr symbolAttrNameId) {
  StringAttr name = op->getAttrOfType<StringAttr>(symbolAttrNameId);
  return name;
}
/// Appends to `results` the references that name `symbol` (whose name is
/// `symbolName`) from successively outer symbol tables, up to `within`.
///
/// `results[0]` is always the flat reference `@symbolName`. Each further
/// entry qualifies the previous one by the name of the next enclosing symbol
/// table, e.g. `@outer::@inner::@symbolName`.
///
/// Returns failure if an enclosing operation strictly between `symbol` and
/// `within` is not a named symbol table. In that case `results` holds the
/// references collected so far.
static LogicalResult
collectValidReferencesFor(Operation *symbol, StringAttr symbolName,
                          Operation *within,
                          SmallVectorImpl<SymbolRefAttr> &results) {
  assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor");
  MLIRContext *ctx = symbol->getContext();

  FlatSymbolRefAttr leaf = FlatSymbolRefAttr::get(symbolName);
  results.push_back(leaf);

  // Nothing more to do when 'within' is the symbol's own table.
  Operation *table = symbol->getParentOp();
  if (table == within)
    return success();

  StringAttr nameKey = StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
  // The nested part of the reference; grows at the front as we move outward.
  SmallVector<FlatSymbolRefAttr, 1> path(1, leaf);
  for (;;) {
    if (!table->hasTrait<OpTrait::SymbolTable>())
      return failure();
    StringAttr tableName = getNameIfSymbol(table, nameKey);
    if (!tableName)
      return failure();
    results.push_back(SymbolRefAttr::get(tableName, path));

    table = table->getParentOp();
    if (table == within)
      return success();
    path.insert(path.begin(), FlatSymbolRefAttr::get(tableName));
  }
}
/// Walks the operations nested within `regions` and invokes `callback` on
/// each one. Nested symbol tables are not entered.
///
/// The walk stops and returns the callback's result as soon as the callback
/// returns anything other than `WalkResult::advance()`, including `llvm::None`.
/// Operations with the SymbolTable trait are passed to `callback`, but their
/// regions are not traversed, because references inside them belong to a
/// different symbol scope.
///
/// Returns `WalkResult::advance()` if every visited operation advanced.
static Optional<WalkResult>
walkSymbolTable(MutableArrayRef<Region> regions,
                function_ref<Optional<WalkResult>(Operation *)> callback) {
  SmallVector<Region *, 1> worklist(llvm::make_pointer_range(regions));
  while (!worklist.empty()) {
    for (Operation &op : worklist.pop_back_val()->getOps()) {
      Optional<WalkResult> result = callback(&op);
      if (result != WalkResult::advance())
        return result;

      // A nested symbol table opens a new scope; don't traverse into it.
      if (!op.hasTrait<OpTrait::SymbolTable>()) {
        for (Region &region : op.getRegions())
          worklist.push_back(&region);
      }
    }
  }
  return WalkResult::advance();
}
/// Invokes `callback` on `op` itself. The walk then continues into `op`'s
/// regions (see the region overload), unless the callback did not return
/// `WalkResult::advance()` or `op` is itself a symbol table.
static Optional<WalkResult>
walkSymbolTable(Operation *op,
                function_ref<Optional<WalkResult>(Operation *)> callback) {
  Optional<WalkResult> opResult = callback(op);
  if (opResult != WalkResult::advance())
    return opResult;
  // A symbol table's body is a separate scope; stop at the op itself.
  if (op->hasTrait<OpTrait::SymbolTable>())
    return opResult;
  return walkSymbolTable(op->getRegions(), callback);
}
/// Builds the name -> operation map for the symbols defined directly in the
/// single block of `symbolTableOp`. Operations without a symbol name attribute
/// are skipped. Duplicate names trip an assertion.
SymbolTable::SymbolTable(Operation *symbolTableOp)
    : symbolTableOp(symbolTableOp) {
  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>() &&
         "expected operation to have SymbolTable trait");
  assert(symbolTableOp->getNumRegions() == 1 &&
         "expected operation to have a single region");
  assert(llvm::hasSingleElement(symbolTableOp->getRegion(0)) &&
         "expected operation to have a single block");

  MLIRContext *ctx = symbolTableOp->getContext();
  StringAttr nameKey = StringAttr::get(ctx, SymbolTable::getSymbolAttrName());
  Block &body = symbolTableOp->getRegion(0).front();
  for (Operation &nested : body) {
    StringAttr symName = getNameIfSymbol(&nested, nameKey);
    if (!symName)
      continue;
    bool isNew = symbolTable.insert({symName, &nested}).second;
    (void)isNew;
    assert(isNew &&
           "expected region to contain uniquely named symbol operations");
  }
}
/// Looks up the symbol named `name` in this table. The string is uniqued into
/// a StringAttr first.
Operation *SymbolTable::lookup(StringRef name) const {
  MLIRContext *ctx = symbolTableOp->getContext();
  return lookup(StringAttr::get(ctx, name));
}
/// Looks up the symbol named `name` in the cached map. A missing name yields
/// the map's default value (presumably null).
Operation *SymbolTable::lookup(StringAttr name) const {
  Operation *found = symbolTable.lookup(name);
  return found;
}
/// Removes `symbol` from this table and erases it from the IR. `symbol` must
/// be a named operation directly nested in this table's operation. If the
/// table does not map the symbol's name to `symbol` itself, nothing happens.
void SymbolTable::erase(Operation *symbol) {
  StringAttr name = getNameIfSymbol(symbol);
  assert(name && "expected valid 'name' attribute");
  assert(symbol->getParentOp() == symbolTableOp &&
         "expected this operation to be inside of the operation with this "
         "SymbolTable");

  auto entry = symbolTable.find(name);
  bool isTracked = entry != symbolTable.end() && entry->second == symbol;
  if (!isTracked)
    return;
  symbolTable.erase(entry);
  symbol->erase();
}
/// Inserts `symbol` into this symbol table and returns its (possibly renamed)
/// name.
///
/// If `symbol` has no parent, it is first linked into the table's body at
/// `insertPt`. A default-constructed `insertPt`, or the body's end, means
/// "append". When appending, the symbol is placed before a trailing
/// terminator if one exists.
///
/// If the symbol's name is already taken by a different operation, a suffix
/// `_<N>` is appended, with N drawn from `uniquingCounter`, until the name is
/// unique. The symbol's name attribute is then updated to match.
StringAttr SymbolTable::insert(Operation *symbol, Block::iterator insertPt) {
// Detached symbols are first linked into the body of the symbol table op.
if (!symbol->getParentOp()) {
auto &body = symbolTableOp->getRegion(0).front();
if (insertPt == Block::iterator()) {
insertPt = Block::iterator(body.end());
} else {
assert((insertPt == body.end() ||
insertPt->getParentOp() == symbolTableOp) &&
"expected insertPt to be in the associated module operation");
}
// When appending, keep an existing terminator as the last operation.
if (insertPt == Block::iterator(body.end()) && !body.empty() &&
std::prev(body.end())->hasTrait<OpTrait::IsTerminator>())
insertPt = std::prev(body.end());
body.getOperations().insert(insertPt, symbol);
}
assert(symbol->getParentOp() == symbolTableOp &&
"symbol is already inserted in another op");
// Fast path: the name was not yet present in the table.
StringAttr name = getSymbolName(symbol);
if (symbolTable.insert({name, symbol}).second)
return name;
// This exact symbol is already registered under its name.
if (symbolTable.lookup(name) == symbol)
return name;
// Conflict with a different symbol: try "<name>_<N>" until one is free.
// `uniquingCounter` is not reset here, so it keeps increasing across calls.
SmallString<128> nameBuffer(name.getValue());
unsigned originalLength = nameBuffer.size();
MLIRContext *context = symbol->getContext();
do {
nameBuffer.resize(originalLength);
nameBuffer += '_';
nameBuffer += std::to_string(uniquingCounter++);
} while (!symbolTable.insert({StringAttr::get(context, nameBuffer), symbol})
.second);
// Keep the op's name attribute in sync with the key it was stored under.
setSymbolName(symbol, nameBuffer);
return getSymbolName(symbol);
}
/// Returns the symbol name of `symbol`. It asserts that the op carries a
/// string symbol name attribute.
StringAttr SymbolTable::getSymbolName(Operation *symbol) {
  StringAttr symName = getNameIfSymbol(symbol);
  assert(symName && "expected valid symbol name");
  return symName;
}
/// Sets (or overwrites) the symbol name attribute of `symbol` to `name`.
void SymbolTable::setSymbolName(Operation *symbol, StringAttr name) {
  StringRef attrName = getSymbolAttrName();
  symbol->setAttr(attrName, name);
}
/// Returns the visibility of `symbol`.
///
/// A missing visibility attribute means public. Otherwise the attribute's
/// string is mapped onto the enum. The StringSwitch has no default, so any
/// other string is unhandled here; detail::verifySymbol rejects such values.
SymbolTable::Visibility SymbolTable::getSymbolVisibility(Operation *symbol) {
  StringAttr visAttr =
      symbol->getAttrOfType<StringAttr>(getVisibilityAttrName());
  if (!visAttr)
    return Visibility::Public;

  StringRef visStr = visAttr.getValue();
  return StringSwitch<Visibility>(visStr)
      .Case("public", Visibility::Public)
      .Case("private", Visibility::Private)
      .Case("nested", Visibility::Nested);
}
/// Sets the visibility of `symbol`. Public is encoded as the absence of the
/// visibility attribute. Private and nested store the matching string.
void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
  MLIRContext *ctx = symbol->getContext();
  if (vis != Visibility::Public) {
    assert((vis == Visibility::Private || vis == Visibility::Nested) &&
           "unknown symbol visibility kind");
    StringRef visName = (vis == Visibility::Private) ? "private" : "nested";
    symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName));
    return;
  }
  symbol->removeAttr(StringAttr::get(ctx, getVisibilityAttrName()));
}
/// Returns the closest operation with the SymbolTable trait, checking `from`
/// itself first and then each parent in turn.
///
/// Returns null if no such operation exists. It also returns null if an op
/// that may be an unknown symbol table is reached first, because we cannot
/// tell whether that op defines a scope.
Operation *SymbolTable::getNearestSymbolTable(Operation *from) {
  assert(from && "expected valid operation");
  for (Operation *current = from; current; current = current->getParentOp()) {
    if (isPotentiallyUnknownSymbolTable(current))
      return nullptr;
    if (current->hasTrait<OpTrait::SymbolTable>())
      return current;
  }
  return nullptr;
}
/// Recursively visits every operation with the SymbolTable trait nested within
/// `op` (including `op`), in post-order. `callback(table, flag)` runs on a
/// table only after all symbol tables nested inside it have been visited.
///
/// The flag starts as `allSymUsesVisible` and can only switch from false to
/// true as the walk descends. It becomes true below any operation that is not
/// a symbol table. It also becomes true at a symbol table that is not a
/// symbol, or that is a private symbol.
void SymbolTable::walkSymbolTables(
    Operation *op, bool allSymUsesVisible,
    function_ref<void(Operation *, bool)> callback) {
  bool isSymbolTable = op->hasTrait<OpTrait::SymbolTable>();
  if (isSymbolTable) {
    SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
    allSymUsesVisible |= !symbol || symbol.isPrivate();
  } else {
    // Non-symbol-table ops force the flag on for everything nested below.
    allSymUsesVisible = true;
  }

  for (Region &region : op->getRegions())
    for (Block &block : region)
      for (Operation &nestedOp : block)
        walkSymbolTables(&nestedOp, allSymUsesVisible, callback);

  // Visit 'op' only after all of its nested symbol tables.
  if (isSymbolTable)
    callback(op, allSymUsesVisible);
}
/// Returns the operation directly nested in the first block of
/// `symbolTableOp`'s region whose symbol name equals `symbol`.
///
/// Returns null if the region is empty or no such operation exists. This is a
/// linear scan; SymbolTable and SymbolTableCollection cache a name map instead.
Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
                                       StringAttr symbol) {
  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());
  Region &region = symbolTableOp->getRegion(0);
  if (region.empty())
    return nullptr;

  StringAttr symbolNameId = StringAttr::get(symbolTableOp->getContext(),
                                            SymbolTable::getSymbolAttrName());
  for (Operation &op : region.front())
    if (getNameIfSymbol(&op, symbolNameId) == symbol)
      return &op;
  return nullptr;
}
/// Resolves the possibly nested reference `symbol`, starting at
/// `symbolTableOp`. Returns the final (leaf) operation, or null if resolution
/// fails.
Operation *SymbolTable::lookupSymbolIn(Operation *symbolTableOp,
                                       SymbolRefAttr symbol) {
  SmallVector<Operation *, 4> path;
  LogicalResult status = lookupSymbolIn(symbolTableOp, symbol, path);
  return succeeded(status) ? path.back() : nullptr;
}
/// Resolves `symbol` starting at `symbolTableOp`, using `lookupSymbolFn` to
/// find a name in a single symbol table.
///
/// Each resolved operation along the reference path (root, intermediate
/// tables, leaf) is appended to `symbols`. Every non-leaf component must
/// resolve to an operation with the SymbolTable trait.
///
/// On failure, `symbols` may hold a partial path. If only the leaf failed to
/// resolve, `symbols` ends with a null entry.
static LogicalResult lookupSymbolInImpl(
    Operation *symbolTableOp, SymbolRefAttr symbol,
    SmallVectorImpl<Operation *> &symbols,
    function_ref<Operation *(Operation *, StringAttr)> lookupSymbolFn) {
  assert(symbolTableOp->hasTrait<OpTrait::SymbolTable>());

  Operation *current = lookupSymbolFn(symbolTableOp, symbol.getRootReference());
  if (!current)
    return failure();
  symbols.push_back(current);

  ArrayRef<FlatSymbolRefAttr> nested = symbol.getNestedReferences();
  if (nested.empty())
    return success();

  // The root must itself be a table for nested lookups to make sense.
  if (!current->hasTrait<OpTrait::SymbolTable>())
    return failure();
  for (FlatSymbolRefAttr ref : nested.drop_back()) {
    current = lookupSymbolFn(current, ref.getAttr());
    if (!current || !current->hasTrait<OpTrait::SymbolTable>())
      return failure();
    symbols.push_back(current);
  }

  Operation *leaf = lookupSymbolFn(current, symbol.getLeafReference());
  symbols.push_back(leaf);
  return success(leaf != nullptr);
}
/// Resolves `symbol` from `symbolTableOp` by scanning each table's body
/// directly, with no caching. Every operation on the reference path is
/// appended to `symbols`.
LogicalResult
SymbolTable::lookupSymbolIn(Operation *symbolTableOp, SymbolRefAttr symbol,
                            SmallVectorImpl<Operation *> &symbols) {
  auto scanTable = [](Operation *table, StringAttr name) {
    return lookupSymbolIn(table, name);
  };
  return lookupSymbolInImpl(symbolTableOp, symbol, symbols, scanTable);
}
/// Looks up `symbol` in the nearest symbol table that encloses (or is)
/// `from`. Returns null if there is no such table or no such symbol.
Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
                                                StringAttr symbol) {
  if (Operation *table = getNearestSymbolTable(from))
    return lookupSymbolIn(table, symbol);
  return nullptr;
}
/// Same as above for a possibly nested symbol reference.
Operation *SymbolTable::lookupNearestSymbolFrom(Operation *from,
                                                SymbolRefAttr symbol) {
  if (Operation *table = getNearestSymbolTable(from))
    return lookupSymbolIn(table, symbol);
  return nullptr;
}
/// Prints the keyword for `visibility`: "public", "private", or "nested".
raw_ostream &mlir::operator<<(raw_ostream &os,
                              SymbolTable::Visibility visibility) {
  const char *keyword = nullptr;
  switch (visibility) {
  case SymbolTable::Visibility::Public:
    keyword = "public";
    break;
  case SymbolTable::Visibility::Private:
    keyword = "private";
    break;
  case SymbolTable::Visibility::Nested:
    keyword = "nested";
    break;
  }
  if (!keyword)
    llvm_unreachable("Unexpected visibility");
  return os << keyword;
}
/// Verifies an operation with the SymbolTable trait. It checks that:
/// - the operation has exactly one region with exactly one block;
/// - no two directly nested operations share a symbol name;
/// - every SymbolUserOpInterface operation reachable without entering nested
///   symbol tables passes `verifySymbolUses`.
LogicalResult detail::verifySymbolTable(Operation *op) {
  if (op->getNumRegions() != 1)
    return op->emitOpError()
           << "Operations with a 'SymbolTable' must have exactly one region";
  if (!llvm::hasSingleElement(op->getRegion(0)))
    return op->emitOpError()
           << "Operations with a 'SymbolTable' must have exactly one block";

  // Detect duplicate names, remembering where each name was first defined so
  // the diagnostic can point at it.
  DenseMap<Attribute, Location> firstDefinitionLoc;
  for (Block &block : op->getRegion(0)) {
    for (Operation &nestedOp : block) {
      StringAttr nameAttr = nestedOp.getAttrOfType<StringAttr>(
          mlir::SymbolTable::getSymbolAttrName());
      if (!nameAttr)
        continue;
      auto insertion =
          firstDefinitionLoc.try_emplace(nameAttr, nestedOp.getLoc());
      if (insertion.second)
        continue;
      return nestedOp.emitError()
          .append("redefinition of symbol named '", nameAttr.getValue(), "'")
          .attachNote(insertion.first->second)
          .append("see existing symbol definition here");
    }
  }

  // Let each nested symbol user verify its references.
  SymbolTableCollection symbolTable;
  auto verifyUser = [&](Operation *nested) -> Optional<WalkResult> {
    if (SymbolUserOpInterface user = dyn_cast<SymbolUserOpInterface>(nested))
      return WalkResult(user.verifySymbolUses(symbolTable));
    return WalkResult::advance();
  };
  Optional<WalkResult> result = walkSymbolTable(op->getRegions(), verifyUser);
  return success(result && !result->wasInterrupted());
}
/// Verifies that `op` has a string symbol name attribute. If a visibility
/// attribute is present, it must be a string equal to "public", "private", or
/// "nested".
LogicalResult detail::verifySymbol(Operation *op) {
  StringRef nameAttrName = mlir::SymbolTable::getSymbolAttrName();
  if (!op->getAttrOfType<StringAttr>(nameAttrName))
    return op->emitOpError()
           << "requires string attribute '" << nameAttrName << "'";

  StringRef visAttrName = mlir::SymbolTable::getVisibilityAttrName();
  Attribute vis = op->getAttr(visAttrName);
  if (!vis)
    return success();

  StringAttr visStrAttr = vis.dyn_cast<StringAttr>();
  if (!visStrAttr)
    return op->emitOpError() << "requires visibility attribute '"
                             << visAttrName
                             << "' to be a string attribute, but got " << vis;

  StringRef visValue = visStrAttr.getValue();
  bool isKnown =
      visValue == "public" || visValue == "private" || visValue == "nested";
  if (!isKnown)
    return op->emitOpError()
           << "visibility expected to be one of [\"public\", \"private\", "
              "\"nested\"], but got "
           << visStrAttr;
  return success();
}
/// Walks every SymbolRefAttr reachable from `op`'s attribute dictionary and
/// invokes `callback` with each use.
///
/// The walk is depth-first and descends into attributes that implement
/// SubElementAttrInterface. A SymbolRefAttr is reported as a whole; its own
/// components are not walked.
///
/// If `callback` returns interrupt, the walk stops and the interrupt is
/// returned to the caller.
static WalkResult
walkSymbolRefs(Operation *op,
               function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
  DictionaryAttr attrDict = op->getAttrDictionary();
  if (attrDict.empty())
    return WalkResult::advance();

  // Snapshot of a container attribute's immediate attribute sub-elements.
  // Type sub-elements are ignored.
  struct WorklistItem {
    SmallVector<Attribute> immediateSubElements;

    explicit WorklistItem(SubElementAttrInterface container) {
      container.walkImmediateSubElements(
          [&](Attribute attr) { immediateSubElements.push_back(attr); },
          [](Type) {});
    }
  };

  // `attrWorklist` and `curAccessChain` are parallel stacks. For each
  // container being walked, the chain holds the index of the sub-element
  // currently being processed, or -1 before the first one.
  SmallVector<WorklistItem, 1> attrWorklist(1, WorklistItem(attrDict));
  SmallVector<int, 1> curAccessChain(1, -1);

  // Processes `worklistItem`'s sub-elements, starting at `index`.
  // - When a nested container is found, it is pushed and the function returns
  //   immediately, so the container is walked before this one resumes. The
  //   pushes may invalidate `index` and `worklistItem`.
  // - When all elements are done, `worklistItem` is popped.
  auto processAttrs = [&](int &index,
                          WorklistItem &worklistItem) -> WalkResult {
    for (Attribute attr :
         llvm::drop_begin(worklistItem.immediateSubElements, index)) {
      if (auto symbolRef = attr.dyn_cast<SymbolRefAttr>()) {
        if (callback({op, symbolRef}).wasInterrupted())
          return WalkResult::interrupt();
      } else if (auto interface = attr.dyn_cast<SubElementAttrInterface>()) {
        attrWorklist.emplace_back(interface);
        curAccessChain.push_back(-1);
        return WalkResult::advance();
      }
      ++index;
    }
    attrWorklist.pop_back();
    curAccessChain.pop_back();
    return WalkResult::advance();
  };

  WalkResult result = WalkResult::advance();
  do {
    WorklistItem &item = attrWorklist.back();
    int &index = curAccessChain.back();
    // Step past the element handled last time (or to 0 for a new container).
    ++index;
    result = processAttrs(index, item);
  } while (!attrWorklist.empty() && !result.wasInterrupted());
  return result;
}
/// Walks the symbol uses in all operations nested within `regions`, without
/// entering nested symbol tables. Returns None if a potentially unknown symbol
/// table is encountered, because uses inside it are conservatively treated as
/// unknowable.
static Optional<WalkResult>
walkSymbolUses(MutableArrayRef<Region> regions,
               function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
  auto visitOp = [&](Operation *op) -> Optional<WalkResult> {
    if (isPotentiallyUnknownSymbolTable(op))
      return llvm::None;
    return walkSymbolRefs(op, callback);
  };
  return walkSymbolTable(regions, visitOp);
}
/// Walks the symbol uses on `from` itself and then, if `from` is not a symbol
/// table, the uses nested in its regions. Returns None if `from` (or a nested
/// op) may be an unknown symbol table.
static Optional<WalkResult>
walkSymbolUses(Operation *from,
               function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
  if (isPotentiallyUnknownSymbolTable(from))
    return llvm::None;
  if (walkSymbolRefs(from, callback).wasInterrupted())
    return WalkResult::interrupt();
  // A symbol table's own attributes were walked above; its body is a
  // separate scope.
  if (from->hasTrait<OpTrait::SymbolTable>())
    return WalkResult::advance();
  return walkSymbolUses(from->getRegions(), callback);
}
namespace {
/// A unit of IR in which uses of a symbol are searched for. It pairs the
/// reference spelling (`symbol`) that names the symbol within this scope with
/// the IR unit (`limit`) that bounds the search.
struct SymbolScope {
/// Walks the symbol uses within this scope, invoking `cback` on each.
/// This overload is selected when `cback` returns a non-void value
/// (presumably WalkResult), matching what `walkSymbolUses` expects. When
/// `limit` is an operation, that operation's own attributes are walked too.
template <typename CallbackT,
typename std::enable_if_t<!std::is_same<
typename llvm::function_traits<CallbackT>::result_t,
void>::value> * = nullptr>
Optional<WalkResult> walk(CallbackT cback) {
if (Region *region = limit.dyn_cast<Region *>())
return walkSymbolUses(*region, cback);
return walkSymbolUses(limit.get<Operation *>(), cback);
}
/// This overload is selected for callbacks returning void. Each use is
/// forwarded to `cback` and the walk always advances.
template <typename CallbackT,
typename std::enable_if_t<std::is_same<
typename llvm::function_traits<CallbackT>::result_t,
void>::value> * = nullptr>
Optional<WalkResult> walk(CallbackT cback) {
return walk([=](SymbolTable::SymbolUse use) {
return cback(use), WalkResult::advance();
});
}
/// Walks the operations nested under this scope without entering nested
/// symbol tables, invoking `cback` on each (see ::walkSymbolTable).
template <typename CallbackT>
Optional<WalkResult> walkSymbolTable(CallbackT &&cback) {
if (Region *region = limit.dyn_cast<Region *>())
return ::walkSymbolTable(*region, cback);
return ::walkSymbolTable(limit.get<Operation *>(), cback);
}
/// The reference to the symbol, spelled as it must appear within this scope.
SymbolRefAttr symbol;
/// The operation or region that bounds this scope.
llvm::PointerUnion<Operation *, Region *> limit;
};
} // namespace
/// Computes the scopes that must be searched for uses of `symbol` when looking
/// within `limit`, each paired with the reference spelling valid there.
///
/// - If `symbol` is `limit` or one of its ancestors, a single flat-reference
///   scope over `limit` is returned. This happens only when the nearest symbol
///   table above `limit` is `symbol`'s parent. Otherwise no scopes are
///   returned, because SymbolRefAttr cannot reference a parent.
/// - If `limit` contains `symbol` (it is the common ancestor), one scope is
///   returned per reference collected by collectValidReferencesFor. Each scope
///   is the region of the corresponding enclosing symbol table, starting at
///   `symbol`'s parent and moving outward.
/// - Otherwise `limit` is searched with the fully qualified reference relative
///   to the common ancestor. If that reference could not be built, no scopes
///   are returned.
static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
Operation *limit) {
StringAttr symName = SymbolTable::getSymbolName(symbol);
assert(!symbol->hasTrait<OpTrait::SymbolTable>() || symbol != limit);
// Collect `limit` and all of its ancestors, checking along the way whether
// `symbol` is one of them.
SetVector<Operation *, SmallVector<Operation *, 4>,
SmallPtrSet<Operation *, 4>>
limitAncestors;
Operation *limitAncestor = limit;
do {
if (limitAncestor == symbol) {
// References are only expressible if `symbol` lives in the table that
// `limit` resolves names against.
if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) ==
symbol->getParentOp())
return {{SymbolRefAttr::get(symName), limit}};
return {};
}
limitAncestors.insert(limitAncestor);
} while ((limitAncestor = limitAncestor->getParentOp()));
// Find the first ancestor of `symbol` that is also an ancestor of `limit`.
Operation *commonAncestor = symbol->getParentOp();
do {
if (limitAncestors.count(commonAncestor))
break;
} while ((commonAncestor = commonAncestor->getParentOp()));
assert(commonAncestor && "'limit' and 'symbol' have no common ancestor");
// references[i] is the spelling valid inside the i-th enclosing table.
SmallVector<SymbolRefAttr, 2> references;
bool collectedAllReferences = succeeded(
collectValidReferencesFor(symbol, symName, commonAncestor, references));
if (commonAncestor == limit) {
SmallVector<SymbolScope, 2> scopes;
// Pair each collected reference with the body region of the table in which
// it is valid.
Operation *limitIt = symbol->getParentOp();
for (size_t i = 0, e = references.size(); i != e;
++i, limitIt = limitIt->getParentOp()) {
assert(limitIt->hasTrait<OpTrait::SymbolTable>());
scopes.push_back({references[i], &limitIt->getRegion(0)});
}
return scopes;
}
// `limit` sits outside the symbol's table chain; it needs the reference as
// seen from the common ancestor, which only exists if every level resolved.
if (!collectedAllReferences)
return {};
return {{references.back(), limit}};
}
/// Region-limited variant. It computes the scopes for the region's parent
/// operation, then narrows the last scope's limit to the region itself.
static SmallVector<SymbolScope, 2> collectSymbolScopes(Operation *symbol,
                                                       Region *limit) {
  SmallVector<SymbolScope, 2> scopes =
      collectSymbolScopes(symbol, limit->getParentOp());
  if (scopes.empty())
    return scopes;
  scopes.back().limit = limit;
  return scopes;
}
/// When the symbol is known only by name, there is a single scope: the flat
/// reference to that name, searched within `limit`.
template <typename IRUnit>
static SmallVector<SymbolScope, 1> collectSymbolScopes(StringAttr symbol,
                                                       IRUnit *limit) {
  SymbolScope scope{SymbolRefAttr::get(symbol), limit};
  return {scope};
}
/// Returns true if `subRef` equals `ref`, or if `subRef` is a strict prefix of
/// `ref`. A strict prefix has the same root, and its nested references are a
/// shorter leading run of `ref`'s nested references.
static bool isReferencePrefixOf(SymbolRefAttr subRef, SymbolRefAttr ref) {
  if (ref == subRef)
    return true;
  // A flat `ref` has nothing to be a strict prefix of.
  if (ref.isa<FlatSymbolRefAttr>())
    return false;
  if (ref.getRootReference() != subRef.getRootReference())
    return false;

  ArrayRef<FlatSymbolRefAttr> refNested = ref.getNestedReferences();
  ArrayRef<FlatSymbolRefAttr> subNested = subRef.getNestedReferences();
  if (subNested.size() >= refNested.size())
    return false;
  return subNested == refNested.take_front(subNested.size());
}
/// Collects every symbol use within `from` (an operation or a range of
/// regions). Returns None if the walk could not determine all uses.
template <typename FromT>
static Optional<SymbolTable::UseRange> getSymbolUsesImpl(FromT from) {
  std::vector<SymbolTable::SymbolUse> uses;
  Optional<WalkResult> walked =
      walkSymbolUses(from, [&](SymbolTable::SymbolUse symbolUse) {
        uses.push_back(symbolUse);
        return WalkResult::advance();
      });
  if (!walked)
    return llvm::None;
  return SymbolTable::UseRange(std::move(uses));
}
/// Returns all symbol uses on `from` and nested within it, without entering
/// nested symbol tables. Returns None if the uses cannot be determined.
auto SymbolTable::getSymbolUses(Operation *from) -> Optional<UseRange> {
  return ::getSymbolUsesImpl(from);
}
/// Returns all symbol uses nested within the region `from`.
auto SymbolTable::getSymbolUses(Region *from) -> Optional<UseRange> {
  MutableArrayRef<Region> regions(*from);
  return ::getSymbolUsesImpl(regions);
}
/// Collects the uses of `symbol` (an operation or a name) within `limit`,
/// searching every scope from which the symbol may be referenced.
///
/// A use matches when the scope's reference is a prefix of the used reference.
/// Returns None if any scope could not be fully walked.
template <typename SymbolT, typename IRUnitT>
static Optional<SymbolTable::UseRange> getSymbolUsesImpl(SymbolT symbol,
                                                         IRUnitT *limit) {
  std::vector<SymbolTable::SymbolUse> uses;
  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
    auto collectMatch = [&](SymbolTable::SymbolUse symbolUse) {
      if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
        uses.push_back(symbolUse);
    };
    Optional<WalkResult> walked = scope.walk(collectMatch);
    if (!walked)
      return llvm::None;
  }
  return SymbolTable::UseRange(std::move(uses));
}
/// Returns the uses of the symbol named `symbol` within `from`, or None if
/// they cannot be determined.
auto SymbolTable::getSymbolUses(StringAttr symbol, Operation *from)
    -> Optional<UseRange> {
  return ::getSymbolUsesImpl(symbol, from);
}
/// Returns the uses of the symbol operation `symbol` within `from`.
auto SymbolTable::getSymbolUses(Operation *symbol, Operation *from)
    -> Optional<UseRange> {
  return ::getSymbolUsesImpl(symbol, from);
}
/// Returns the uses of the symbol named `symbol` within the region `from`.
auto SymbolTable::getSymbolUses(StringAttr symbol, Region *from)
    -> Optional<UseRange> {
  return ::getSymbolUsesImpl(symbol, from);
}
/// Returns the uses of the symbol operation `symbol` within the region `from`.
auto SymbolTable::getSymbolUses(Operation *symbol, Region *from)
    -> Optional<UseRange> {
  return ::getSymbolUsesImpl(symbol, from);
}
/// Returns true only if `symbol` is known to have no uses within `limit`.
/// Returns false if a use is found, or if some scope's walk produced no
/// result.
template <typename SymbolT, typename IRUnitT>
static bool symbolKnownUseEmptyImpl(SymbolT symbol, IRUnitT *limit) {
  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
    auto stopAtUse = [&](SymbolTable::SymbolUse symbolUse) -> WalkResult {
      if (isReferencePrefixOf(scope.symbol, symbolUse.getSymbolRef()))
        return WalkResult::interrupt();
      return WalkResult::advance();
    };
    Optional<WalkResult> walked = scope.walk(stopAtUse);
    if (walked != WalkResult::advance())
      return false;
  }
  return true;
}
/// Returns true if the symbol named `symbol` is known to be unused in `from`.
bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Operation *from) {
  return ::symbolKnownUseEmptyImpl(symbol, from);
}
/// Returns true if the symbol operation `symbol` is known to be unused in
/// `from`.
bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Operation *from) {
  return ::symbolKnownUseEmptyImpl(symbol, from);
}
/// Returns true if the symbol named `symbol` is known to be unused in the
/// region `from`.
bool SymbolTable::symbolKnownUseEmpty(StringAttr symbol, Region *from) {
  return ::symbolKnownUseEmptyImpl(symbol, from);
}
/// Returns true if the symbol operation `symbol` is known to be unused in the
/// region `from`.
bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
  return ::symbolKnownUseEmptyImpl(symbol, from);
}
/// Returns `oldAttr` with its leaf component replaced by `newLeafAttr`. A flat
/// reference is replaced by `newLeafAttr` outright.
static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr,
                                        FlatSymbolRefAttr newLeafAttr) {
  if (oldAttr.isa<FlatSymbolRefAttr>())
    return newLeafAttr;
  ArrayRef<FlatSymbolRefAttr> oldNested = oldAttr.getNestedReferences();
  SmallVector<FlatSymbolRefAttr, 2> nested(oldNested.begin(), oldNested.end());
  nested.back() = newLeafAttr;
  return SymbolRefAttr::get(oldAttr.getRootReference(), nested);
}
/// Replaces all uses of `symbol` (an operation or a name) within `limit` with
/// references to `newSymbol`.
///
/// For each scope from collectSymbolScopes:
/// - An attribute equal to the scope's reference is replaced by the same
///   reference with its leaf renamed to `newSymbol`.
/// - A reference that merely starts with the scope's reference (e.g.
///   `@old::@inner` when replacing `@old`) has the matching component renamed
///   and keeps its remaining nested parts.
///
/// Returns failure if a scope's walk produced no result, or if rebuilding an
/// operation's attribute dictionary failed. In the latter case the walk stops
/// early, and uses already rewritten are not rolled back.
template <typename SymbolT, typename IRUnitT>
static LogicalResult
replaceAllSymbolUsesImpl(SymbolT symbol, StringAttr newSymbol, IRUnitT *limit) {
  FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol);
  for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) {
    SymbolRefAttr oldAttr = scope.symbol;
    SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr);
    auto walkFn = [&](Operation *op) -> Optional<WalkResult> {
      auto remapAttrFn = [&](Attribute attr) -> Attribute {
        if (attr == oldAttr)
          return newAttr;
        // Handle references that have `oldAttr` as a strict prefix.
        if (SymbolRefAttr symRef = attr.dyn_cast<SymbolRefAttr>()) {
          if (isReferencePrefixOf(oldAttr, symRef)) {
            auto oldNestedRefs = oldAttr.getNestedReferences();
            auto nestedRefs = symRef.getNestedReferences();
            // A flat `oldAttr` names `symRef`'s root.
            if (oldNestedRefs.empty())
              return SymbolRefAttr::get(newSymbol, nestedRefs);
            // Otherwise `oldAttr`'s leaf sits at the same nested position.
            auto newNestedRefs = llvm::to_vector<4>(nestedRefs);
            newNestedRefs[oldNestedRefs.size() - 1] = newLeafAttr;
            return SymbolRefAttr::get(symRef.getRootReference(), newNestedRefs);
          }
        }
        return attr;
      };
      auto newDict = op->getAttrDictionary().replaceSubElements(remapAttrFn);
      if (!newDict)
        return WalkResult::interrupt();
      op->setAttrs(newDict.template cast<DictionaryAttr>());
      return WalkResult::advance();
    };
    // An interrupted walk means a dictionary could not be rebuilt. Testing the
    // Optional only for presence would wrongly report that as success.
    Optional<WalkResult> result = scope.walkSymbolTable(walkFn);
    if (!result || result->wasInterrupted())
      return failure();
  }
  return success();
}
/// Renames uses of the symbol named `oldSymbol` within `from` to `newSymbol`.
LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
                                                StringAttr newSymbol,
                                                Operation *from) {
  return ::replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
}
/// Renames uses of the symbol operation `oldSymbol` within `from`.
LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
                                                StringAttr newSymbol,
                                                Operation *from) {
  return ::replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
}
/// Renames uses of the symbol named `oldSymbol` within the region `from`.
LogicalResult SymbolTable::replaceAllSymbolUses(StringAttr oldSymbol,
                                                StringAttr newSymbol,
                                                Region *from) {
  return ::replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
}
/// Renames uses of the symbol operation `oldSymbol` within the region `from`.
LogicalResult SymbolTable::replaceAllSymbolUses(Operation *oldSymbol,
                                                StringAttr newSymbol,
                                                Region *from) {
  return ::replaceAllSymbolUsesImpl(oldSymbol, newSymbol, from);
}
/// Looks up `symbol` in `symbolTableOp`. It uses the cached SymbolTable for
/// that operation, which is built on first use.
Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
                                                 StringAttr symbol) {
  SymbolTable &table = getSymbolTable(symbolTableOp);
  return table.lookup(symbol);
}
/// Resolves the possibly nested reference `name` using cached tables. Returns
/// the leaf operation, or null on failure.
Operation *SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
                                                 SymbolRefAttr name) {
  SmallVector<Operation *, 4> path;
  if (succeeded(lookupSymbolIn(symbolTableOp, name, path)))
    return path.back();
  return nullptr;
}
/// Resolves `name` using cached per-operation symbol tables. Every operation
/// on the reference path is appended to `symbols`.
LogicalResult
SymbolTableCollection::lookupSymbolIn(Operation *symbolTableOp,
                                      SymbolRefAttr name,
                                      SmallVectorImpl<Operation *> &symbols) {
  auto cachedLookup = [this](Operation *table, StringAttr symbol) {
    return lookupSymbolIn(table, symbol);
  };
  return lookupSymbolInImpl(symbolTableOp, name, symbols, cachedLookup);
}
/// Looks up `symbol` in the nearest symbol table that encloses (or is)
/// `from`, using cached tables. Returns null if there is no such table or no
/// such symbol.
Operation *SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
                                                          StringAttr symbol) {
  if (Operation *table = SymbolTable::getNearestSymbolTable(from))
    return lookupSymbolIn(table, symbol);
  return nullptr;
}
/// Same as above for a possibly nested symbol reference.
Operation *
SymbolTableCollection::lookupNearestSymbolFrom(Operation *from,
                                               SymbolRefAttr symbol) {
  if (Operation *table = SymbolTable::getNearestSymbolTable(from))
    return lookupSymbolIn(table, symbol);
  return nullptr;
}
/// Returns the SymbolTable for `op`, constructing and caching it the first
/// time `op` is requested.
SymbolTable &SymbolTableCollection::getSymbolTable(Operation *op) {
  auto insertion = symbolTables.try_emplace(op, nullptr);
  auto &table = insertion.first->second;
  if (insertion.second)
    table = std::make_unique<SymbolTable>(op);
  return *table;
}
/// Builds a map from each symbol to the operations that reference it. It
/// covers every symbol table nested within (and including) `symbolTableOp`.
/// Each operation on a resolved reference path gets the user recorded; for
/// `@a::@b`, both `@a` and `@b` record it.
SymbolUserMap::SymbolUserMap(SymbolTableCollection &symbolTable,
                             Operation *symbolTableOp)
    : symbolTable(symbolTable) {
  SmallVector<Operation *> resolved;
  auto recordUsers = [&](Operation *table, bool /*allUsesVisible*/) {
    for (Operation &nestedOp : table->getRegion(0).getOps()) {
      Optional<SymbolTable::UseRange> uses =
          SymbolTable::getSymbolUses(&nestedOp);
      assert(uses && "expected uses to be valid");
      for (const SymbolTable::SymbolUse &use : *uses) {
        resolved.clear();
        (void)symbolTable.lookupSymbolIn(table, use.getSymbolRef(), resolved);
        for (Operation *symbolOp : resolved)
          symbolToUsers[symbolOp].insert(use.getUser());
      }
    }
  };
  // The visibility flag is not needed to build the user map.
  SymbolTable::walkSymbolTables(symbolTableOp, /*allSymUsesVisible=*/false,
                                recordUsers);
}
/// Rewrites the references to `symbol` in every recorded user so that they
/// name `newSymbolName` instead.
///
/// Afterwards the cached collection is asked for the operation named
/// `newSymbolName` in `symbol`'s parent. If that lookup yields anything other
/// than `symbol`, the recorded users move to that operation's entry (merged if
/// an entry already exists) and `symbol`'s entry is erased.
///
/// NOTE(review): if no operation is named `newSymbolName`, `newSymbol` is null
/// and the users are moved under a null key. Confirm this is intended.
void SymbolUserMap::replaceAllUsesWith(Operation *symbol,
StringAttr newSymbolName) {
auto it = symbolToUsers.find(symbol);
if (it == symbolToUsers.end())
return;
// Per-user rewrite failures are deliberately ignored.
for (Operation *user : it->second)
(void)SymbolTable::replaceAllSymbolUses(symbol, newSymbolName, user);
Operation *newSymbol =
symbolTable.lookupSymbolIn(symbol->getParentOp(), newSymbolName);
if (newSymbol != symbol) {
// try_emplace may invalidate `it`, so the old entry is looked up again.
auto newIt = symbolToUsers.try_emplace(newSymbol, SetVector<Operation *>{});
auto oldIt = symbolToUsers.find(symbol);
assert(oldIt != symbolToUsers.end() && "missing old users list");
if (newIt.second)
newIt.first->second = std::move(oldIt->second);
else
newIt.first->second.set_union(oldIt->second);
symbolToUsers.erase(oldIt);
}
}
/// Parses an optional visibility keyword: "public", "private", or "nested".
/// If one is present, it is appended to `attrs` as the visibility attribute.
/// Returns failure when no such keyword is present.
ParseResult impl::parseOptionalVisibilityKeyword(OpAsmParser &parser,
                                                 NamedAttrList &attrs) {
  StringRef keyword;
  if (parser.parseOptionalKeyword(&keyword, {"public", "private", "nested"}))
    return failure();

  Builder &builder = parser.getBuilder();
  attrs.push_back(builder.getNamedAttr(SymbolTable::getVisibilityAttrName(),
                                       builder.getStringAttr(keyword)));
  return success();
}
#include "mlir/IR/SymbolInterfaces.cpp.inc"