#include <utility>
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "llvm/ADT/SmallPtrSet.h"
using namespace mlir;
#include "mlir/Interfaces/ControlFlowInterfaces.cpp.inc"
/// Constructs successor operands made up only of `forwardedOperands`; the
/// produced-operand count is set to zero.
SuccessorOperands::SuccessorOperands(MutableOperandRange forwardedOperands)
    : SuccessorOperands(/*producedOperandCount=*/0u,
                        std::move(forwardedOperands)) {}
SuccessorOperands::SuccessorOperands(unsigned int producedOperandCount,
MutableOperandRange forwardedOperands)
: producedOperandCount(producedOperandCount),
forwardedOperands(std::move(forwardedOperands)) {}
/// Returns the argument of `successor` that receives the operand at
/// `operandIndex`. Yields `llvm::None` when there are no forwarded operands or
/// when `operandIndex` lies outside the forwarded operand range.
Optional<BlockArgument>
detail::getBranchSuccessorArgument(const SuccessorOperands &operands,
                                   unsigned operandIndex, Block *successor) {
  OperandRange forwarded = operands.getForwardedOperands();
  if (forwarded.empty())
    return llvm::None;

  // The forwarded operands occupy [firstIdx, pastEndIdx) of the operand list.
  unsigned firstIdx = forwarded.getBeginOperandIndex();
  auto pastEndIdx = firstIdx + forwarded.size();
  bool isForwarded = operandIndex >= firstIdx && operandIndex < pastEndIdx;
  if (!isForwarded)
    return llvm::None;

  // Block arguments for produced operands come first, so shift the position
  // within the forwarded range by the produced operand count.
  unsigned offsetInForwarded = operandIndex - firstIdx;
  unsigned argIndex = operands.getProducedOperandCount() + offsetInForwarded;
  return successor->getArgument(argIndex);
}
/// Verifies that `operands` match the arguments of successor block `succNo` of
/// `op`: the counts must be equal and, for every operand past the produced
/// ones, the operand type must be compatible with the block argument type
/// according to the op's `BranchOpInterface::areTypesCompatible`.
LogicalResult
detail::verifyBranchSuccessorOperands(Operation *op, unsigned succNo,
                                      const SuccessorOperands &operands) {
  Block *target = op->getSuccessor(succNo);
  unsigned numOperands = operands.size();

  // The operand count has to match the target block's argument count.
  if (numOperands != target->getNumArguments())
    return op->emitError() << "branch has " << numOperands
                           << " operands for successor #" << succNo
                           << ", but target block has "
                           << target->getNumArguments();

  // Only operands after the produced ones are type-checked.
  unsigned argNo = operands.getProducedOperandCount();
  while (argNo != numOperands) {
    if (!cast<BranchOpInterface>(op).areTypesCompatible(
            operands[argNo].getType(), target->getArgument(argNo).getType()))
      return op->emitError() << "type mismatch for bb argument #" << argNo
                             << " of successor #" << succNo;
    ++argNo;
  }
  return success();
}
/// Verifies the types flowing along every control flow edge that starts at
/// `sourceNo` (a region number, or None for the parent op). For each successor
/// reported by the op's `RegionBranchOpInterface`, the types returned by
/// `getInputsTypesForRegion` must match the successor inputs in count and be
/// pairwise compatible. Successors for which the callback returns None are
/// skipped.
static LogicalResult
verifyTypesAlongAllEdges(Operation *op, Optional<unsigned> sourceNo,
                         function_ref<Optional<TypeRange>(Optional<unsigned>)>
                             getInputsTypesForRegion) {
  auto branchOp = cast<RegionBranchOpInterface>(op);

  SmallVector<RegionSuccessor, 2> edges;
  branchOp.getSuccessorRegions(sourceNo, edges);

  for (RegionSuccessor &edge : edges) {
    // A successor that is the parent op has no region number.
    Optional<unsigned> targetNo;
    if (!edge.isParent())
      targetNo = edge.getSuccessor()->getRegionNumber();

    // Appends "from <source> to <target>" to a diagnostic.
    auto describeEdge = [&](InFlightDiagnostic &diag) -> InFlightDiagnostic & {
      diag << "from ";
      if (!sourceNo)
        diag << "parent operands";
      else
        diag << "Region #" << sourceNo.value();

      diag << " to ";
      if (!targetNo)
        diag << "parent results";
      else
        diag << "Region #" << targetNo.value();
      return diag;
    };

    Optional<TypeRange> maybeSourceTypes = getInputsTypesForRegion(targetNo);
    if (!maybeSourceTypes.has_value())
      continue;

    TypeRange targetTypes = edge.getSuccessorInputs().getTypes();
    if (maybeSourceTypes->size() != targetTypes.size()) {
      InFlightDiagnostic diag = op->emitOpError(" region control flow edge ");
      return describeEdge(diag)
             << ": source has " << maybeSourceTypes->size()
             << " operands, but target successor needs " << targetTypes.size();
    }

    // Counts agree at this point, so compare the types position by position.
    for (size_t pos = 0, numTypes = targetTypes.size(); pos != numTypes;
         ++pos) {
      Type sourceType = (*maybeSourceTypes)[pos];
      Type inputType = targetTypes[pos];
      if (branchOp.areTypesCompatible(sourceType, inputType))
        continue;
      InFlightDiagnostic diag = op->emitOpError(" along control flow edge ");
      return describeEdge(diag)
             << ": source type #" << pos << " " << sourceType
             << " should match input type #" << pos << " " << inputType;
    }
  }
  return success();
}
/// Verifies the types along all control flow edges of `op`, which must
/// implement `RegionBranchOpInterface`:
///   * edges leaving the parent op, using the types of the successor entry
///     operands;
///   * edges leaving each region, using the operand types of the region's
///     terminators for which `getRegionBranchSuccessorOperands` yields
///     operands.
/// If a region has several such terminators, their operand types must be
/// pairwise compatible with those of the first one found.
LogicalResult detail::verifyTypesAlongControlFlowEdges(Operation *op) {
  auto regionInterface = cast<RegionBranchOpInterface>(op);

  // Edges whose source is the parent operation.
  auto inputTypesFromParent = [&](Optional<unsigned> regionNo) -> TypeRange {
    return regionInterface.getSuccessorEntryOperands(regionNo).getTypes();
  };
  if (failed(verifyTypesAlongAllEdges(op, llvm::None, inputTypesFromParent)))
    return failure();

  assert(op->getNumRegions() != 0 && "expected op to have at least one region");

  // Returns true if both ranges have equal length and every pair of types at
  // the same position is compatible according to the interface.
  auto areTypesCompatible = [&](TypeRange lhs, TypeRange rhs) {
    if (lhs.size() != rhs.size())
      return false;
    for (auto types : llvm::zip(lhs, rhs)) {
      if (!regionInterface.areTypesCompatible(std::get<0>(types),
                                              std::get<1>(types))) {
        return false;
      }
    }
    return true;
  };

  // Edges whose source is one of the op's regions.
  for (unsigned regionNo : llvm::seq(0U, op->getNumRegions())) {
    Region &region = op->getRegion(regionNo);

    // Operands of the first terminator in this region that has successor
    // operands; later ones are compared against it.
    Optional<OperandRange> regionReturnOperands;
    for (Block &block : region) {
      Operation *terminator = block.getTerminator();
      auto terminatorOperands =
          getRegionBranchSuccessorOperands(terminator, regionNo);
      if (!terminatorOperands)
        continue;

      if (!regionReturnOperands) {
        regionReturnOperands = terminatorOperands;
        continue;
      }

      // Another terminator with successor operands: its operand types have to
      // agree with the first one.
      if (!areTypesCompatible(regionReturnOperands->getTypes(),
                              terminatorOperands->getTypes()))
        return op->emitOpError("Region #")
               << regionNo
               << " operands mismatch between return-like terminators";
    }

    // The source types do not depend on the successor region; None makes the
    // edge check skip regions without any terminator carrying operands.
    auto inputTypesFromRegion =
        [&](Optional<unsigned>) -> Optional<TypeRange> {
      if (!regionReturnOperands)
        return llvm::None;
      return TypeRange(regionReturnOperands->getTypes());
    };
    if (failed(verifyTypesAlongAllEdges(op, regionNo, inputTypesFromRegion)))
      return failure();
  }

  return success();
}
/// Returns true if region `r` can be reached from region `begin` by following
/// one or more region-successor edges of their common parent op, as reported
/// by `RegionBranchOpInterface::getSuccessorRegions`. Successors that are the
/// parent op itself are not followed.
static bool isRegionReachable(Region *begin, Region *r) {
  assert(begin->getParentOp() == r->getParentOp() &&
         "expected that both regions belong to the same op");
  auto branchOp = cast<RegionBranchOpInterface>(begin->getParentOp());
  const unsigned startNo = begin->getRegionNumber();
  const unsigned targetNo = r->getRegionNumber();

  // One flag per region of the parent op; the start region counts as seen.
  SmallVector<bool> seen(branchOp->getNumRegions(), false);
  seen[startNo] = true;

  // Stack of region numbers that still have to be examined.
  SmallVector<unsigned> pending;
  auto pushRegionSuccessors = [&](unsigned regionNo) {
    SmallVector<RegionSuccessor> next;
    branchOp.getSuccessorRegions(regionNo, next);
    for (RegionSuccessor &candidate : next) {
      if (candidate.isParent())
        continue;
      pending.push_back(candidate.getSuccessor()->getRegionNumber());
    }
  };

  pushRegionSuccessors(startNo);
  while (!pending.empty()) {
    unsigned current = pending.pop_back_val();
    // The target test precedes the seen test so that `begin == r` is reported
    // as reachable when an edge leads back to the start region.
    if (current == targetNo)
      return true;
    if (seen[current])
      continue;
    seen[current] = true;
    pushRegionSuccessors(current);
  }
  return false;
}
/// Returns true if `a` and `b` sit in different regions of the nearest
/// enclosing `RegionBranchOpInterface` op of `a` that is also a proper
/// ancestor of `b`, and neither of those regions is reachable from the other.
/// Returns false if no such enclosing op exists.
bool mlir::insideMutuallyExclusiveRegions(Operation *a, Operation *b) {
  assert(a && "expected non-empty operation");
  assert(b && "expected non-empty operation");

  // Walk outwards over the RegionBranchOpInterface ancestors of `a`.
  for (auto branchOp = a->getParentOfType<RegionBranchOpInterface>(); branchOp;
       branchOp = branchOp->getParentOfType<RegionBranchOpInterface>()) {
    // Ancestors that do not also enclose `b` are skipped.
    if (!branchOp->isProperAncestor(b))
      continue;

    // Locate the regions of `branchOp` that contain `a` and `b`.
    Region *regionA = nullptr;
    Region *regionB = nullptr;
    for (Region &candidate : branchOp->getRegions()) {
      if (candidate.findAncestorOpInRegion(*a)) {
        assert(!regionA && "already found a region for a");
        regionA = &candidate;
      }
      if (candidate.findAncestorOpInRegion(*b)) {
        assert(!regionB && "already found a region for b");
        regionB = &candidate;
      }
    }
    assert(regionA && regionB && "could not find region of op");

    // The decision is made at the first common ancestor found.
    if (regionA == regionB)
      return false;
    return !(isRegionReachable(regionA, regionB) ||
             isRegionReachable(regionB, regionA));
  }

  return false;
}
bool RegionBranchOpInterface::isRepetitiveRegion(unsigned index) {
Region *region = &getOperation()->getRegion(index);
return isRegionReachable(region, region);
}
void RegionBranchOpInterface::getSuccessorRegions(
Optional<unsigned> index, SmallVectorImpl<RegionSuccessor> ®ions) {
unsigned numInputs = 0;
if (index) {
for (Block &block : getOperation()->getRegion(*index)) {
Operation *terminator = block.getTerminator();
if (getRegionBranchSuccessorOperands(terminator, *index)) {
numInputs = terminator->getNumOperands();
break;
}
}
} else {
numInputs = getOperation()->getNumOperands();
}
SmallVector<Attribute, 2> operands(numInputs, nullptr);
getSuccessorRegions(index, operands, regions);
}
/// Returns the innermost region enclosing `op` whose parent op implements
/// `RegionBranchOpInterface` and reports that region as repetitive, or
/// nullptr if there is no such region.
Region *mlir::getEnclosingRepetitiveRegion(Operation *op) {
  for (Region *enclosing = op->getParentRegion(); enclosing;
       enclosing = op->getParentRegion()) {
    // Step up to the op that owns the current region.
    op = enclosing->getParentOp();
    auto branchOp = dyn_cast<RegionBranchOpInterface>(op);
    if (branchOp && branchOp.isRepetitiveRegion(enclosing->getRegionNumber()))
      return enclosing;
  }
  return nullptr;
}
/// Returns the innermost region, starting from the parent region of `value`,
/// whose parent op implements `RegionBranchOpInterface` and reports that
/// region as repetitive, or nullptr if there is no such region.
Region *mlir::getEnclosingRepetitiveRegion(Value value) {
  for (Region *enclosing = value.getParentRegion(); enclosing;
       enclosing = enclosing->getParentOp()->getParentRegion()) {
    Operation *owner = enclosing->getParentOp();
    auto branchOp = dyn_cast<RegionBranchOpInterface>(owner);
    if (branchOp && branchOp.isRepetitiveRegion(enclosing->getRegionNumber()))
      return enclosing;
  }
  return nullptr;
}
/// Returns true if `operation` implements `RegionBranchTerminatorOpInterface`
/// or carries the `ReturnLike` trait.
bool mlir::isRegionReturnLike(Operation *operation) {
  if (dyn_cast<RegionBranchTerminatorOpInterface>(operation))
    return true;
  return operation->hasTrait<OpTrait::ReturnLike>();
}
/// Returns the mutable operands that `operation` passes on to the successor
/// identified by `regionIndex`. If the op implements
/// `RegionBranchTerminatorOpInterface`, the interface is queried. If it only
/// has the `ReturnLike` trait, a range built from the operation itself is
/// returned. Any other operation yields `llvm::None`.
Optional<MutableOperandRange>
mlir::getMutableRegionBranchSuccessorOperands(Operation *operation,
                                              Optional<unsigned> regionIndex) {
  auto terminatorIface = dyn_cast<RegionBranchTerminatorOpInterface>(operation);
  if (terminatorIface)
    return terminatorIface.getMutableSuccessorOperands(regionIndex);

  if (!operation->hasTrait<OpTrait::ReturnLike>())
    return llvm::None;
  return MutableOperandRange(operation);
}
/// Read-only counterpart of `getMutableRegionBranchSuccessorOperands`:
/// converts the mutable range to an `OperandRange`, or returns `llvm::None`
/// when there is no mutable range.
Optional<OperandRange>
mlir::getRegionBranchSuccessorOperands(Operation *operation,
                                       Optional<unsigned> regionIndex) {
  Optional<MutableOperandRange> mutableRange =
      getMutableRegionBranchSuccessorOperands(operation, regionIndex);
  if (!mutableRange)
    return llvm::None;
  return Optional<OperandRange>(*mutableRange);
}