#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/Dialect/Tosa/Utils/ShapeUtils.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
namespace tosa {
#define GEN_PASS_DEF_TOSAINFERSHAPES
#include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
}
}
using namespace mlir;
using namespace mlir::tosa;
namespace {
// Returns true when `user` can tolerate its operand changing to a refined type:
// either it belongs to the TOSA dialect, or it implements one of the
// type-inference interfaces. Operations without a dialect (e.g. unregistered
// ops) are never considered refinable.
bool canBeRefined(Operation *user) {
  Dialect *dialect = user->getDialect();
  if (dialect == nullptr)
    return false;

  if (dialect->getTypeID() == TypeID::get<TosaDialect>())
    return true;

  return isa<InferTypeOpInterface, InferShapedTypeOpInterface>(user);
}
class TypeModificationState {
public:
TypeModificationState() = default;
~TypeModificationState() {
assert(oldTypes.empty() && "unhandled type modifications");
}
void setType(Value value, Type type) {
if (value.getType() != type) {
oldTypes.emplace_back(value, value.getType());
value.setType(type);
}
}
void rollBack() {
for (auto [value, type] : oldTypes)
value.setType(type);
oldTypes.clear();
}
void commit() {
for (auto [value, oldType] : oldTypes) {
tensor::CastOp castedValue;
for (auto &use : value.getUses()) {
if (canBeRefined(use.getOwner()))
continue;
if (!castedValue) {
ImplicitLocOpBuilder builder{value.getLoc(), use.getOwner()};
castedValue = builder.create<tensor::CastOp>(oldType, value);
}
use.set(castedValue);
}
}
oldTypes.clear();
}
private:
llvm::SmallVector<std::pair<Value, Type>> oldTypes;
};
void propagateShapesInRegion(Region ®ion, TypeModificationState &state);
// If `op` is a tosa IfOp, refines the entry-block argument types of each of
// its regions from the op's operands (operand 0 is skipped; operand i + 1
// feeds block argument i). It then propagates shapes through the region.
// Returns without touching a region whose argument count does not match.
void propagateShapesToTosaIf(Operation &op, TypeModificationState &state) {
  IfOp ifOp = dyn_cast<IfOp>(op);
  if (!ifOp)
    return;

  for (Region &region : op.getRegions()) {
    Block &frontBlock = region.front();
    if (frontBlock.getNumArguments() + 1 != ifOp.getNumOperands())
      return;

    for (unsigned i = 0, e = frontBlock.getNumArguments(); i < e; ++i) {
      BlockArgument blockArg = frontBlock.getArgument(i);
      Type operandTy = ifOp.getOperand(i + 1).getType();
      auto inferredTy = cast<ShapedType>(operandTy);
      auto oldType = cast<ShapedType>(blockArg.getType());

      // Adopt the operand's shape (keeping the block argument's element
      // type) when the operand is ranked.
      Type refinedTy = oldType;
      if (inferredTy.hasRank())
        refinedTy = oldType.clone(inferredTy.getShape());

      // Join with the operand's knowledge. On an incompatible join, keep the
      // shape-adopted type.
      ValueKnowledge joinedKnowledge =
          ValueKnowledge::join(ValueKnowledge::getKnowledgeFromType(operandTy),
                               ValueKnowledge::getKnowledgeFromType(refinedTy));
      if (joinedKnowledge)
        refinedTy = joinedKnowledge.getType();

      // A single setType per argument means `state` records only the
      // argument's original type. Roll-back and commit therefore see it
      // directly, rather than through an intermediate type.
      state.setType(blockArg, refinedTy);
    }

    propagateShapesInRegion(region, state);
  }
}
// If `op` is a tosa WhileOp, computes block-argument types that are
// compatible with every iteration of the loop body. It does this by
// speculatively propagating through the body region (region 1) and meeting
// the tosa.yield operand types with the current argument types until they
// stop changing. Speculative changes are rolled back after each round. The
// final argument types are then applied to every region through `state`, and
// shapes are propagated into each region.
void propagateShapesToTosaWhile(Operation &op, TypeModificationState &state) {
  if (!isa<WhileOp>(op))
    return;

  // Starts from the loop's initial operand types and is widened each round.
  SmallVector<Type> argTypes = llvm::to_vector(op.getOperandTypes());

  bool hasNewTypes = true;
  while (hasNewTypes) {
    TypeModificationState localState;

    // Speculatively apply the current argument types to the body.
    Region &bodyRegion = op.getRegion(1);
    Block &bodyBlock = bodyRegion.front();
    for (int i = 0, s = argTypes.size(); i < s; i++)
      localState.setType(bodyBlock.getArgument(i), argTypes[i]);

    propagateShapesInRegion(bodyRegion, localState);

    llvm::SmallVector<YieldOp> yieldOps;
    for (Block &block : bodyRegion)
      if (auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
        yieldOps.push_back(yieldOp);
    assert(yieldOps.size() == 1 && "missing or non-unique yield op");

    llvm::SmallVector<ValueKnowledge> yieldTypeInfo;
    for (Type ty : argTypes)
      yieldTypeInfo.push_back(ValueKnowledge::getKnowledgeFromType(ty));

    for (YieldOp yieldOp : yieldOps) {
      // Validate the arity before indexing into yieldTypeInfo; it has exactly
      // argTypes.size() entries. Undo the speculative changes before bailing
      // out so the IR is restored and localState's destructor is satisfied.
      if (yieldOp.getNumOperands() != argTypes.size()) {
        localState.rollBack();
        op.emitWarning("has a tosa.yield with the incorrect number of operands");
        return;
      }
      for (const auto &it : llvm::enumerate(yieldOp.getOperands())) {
        auto newKnowledge =
            ValueKnowledge::getKnowledgeFromType(it.value().getType());
        yieldTypeInfo[it.index()] =
            ValueKnowledge::meet(yieldTypeInfo[it.index()], newKnowledge);
      }
    }

    // Adopt the met types and iterate again if any of them changed.
    hasNewTypes = false;
    for (int i = 0, s = yieldTypeInfo.size(); i < s; i++) {
      Type newType = yieldTypeInfo[i].getType();
      hasNewTypes |= (newType != argTypes[i]);
      argTypes[i] = newType;
    }

    localState.rollBack();
  }

  // Apply the converged argument types to every region and propagate.
  for (Region &region : op.getRegions()) {
    for (unsigned int i = 0, s = argTypes.size(); i < s; i++)
      state.setType(region.front().getArgument(i), argTypes[i]);
    propagateShapesInRegion(region, state);
  }
}
void propagateShapesInRegion(Region ®ion, TypeModificationState &state) {
Dialect *tosaDialect = region.getContext()->getLoadedDialect<TosaDialect>();
for (auto &block : region) {
for (Operation &op : block) {
if (op.getDialect() != tosaDialect)
continue;
propagateShapesToTosaIf(op, state);
propagateShapesToTosaWhile(op, state);
InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op);
if (!shapeInterface)
continue;
SmallVector<ShapedTypeComponents> returnedShapes;
if (shapeInterface
.inferReturnTypeComponents(
op.getContext(), op.getLoc(), op.getOperands(),
op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
op.getRegions(), returnedShapes)
.succeeded()) {
for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
Value result = std::get<0>(it);
ShapedTypeComponents predictedShape = std::get<1>(it);
Type resultTy = result.getType();
auto currentKnowledge =
ValueKnowledge::getKnowledgeFromType(resultTy);
auto inferredKnowledge = ValueKnowledge::getPessimisticValueState();
inferredKnowledge.dtype = cast<ShapedType>(resultTy).getElementType();
inferredKnowledge.hasRank = predictedShape.hasRank();
if (predictedShape.hasRank()) {
for (auto dim : predictedShape.getDims()) {
inferredKnowledge.sizes.push_back(dim);
}
}
auto newKnowledge =
ValueKnowledge::join(currentKnowledge, inferredKnowledge);
if (!newKnowledge)
continue;
state.setType(result, newKnowledge.getType());
}
}
}
}
}
// Pass driver: propagates inferred shapes through the body of the function
// and then commits the resulting type changes.
struct TosaInferShapes
    : public tosa::impl::TosaInferShapesBase<TosaInferShapes> {
  void runOnOperation() override {
    TypeModificationState modifications;
    func::FuncOp funcOp = getOperation();
    propagateShapesInRegion(funcOp.getBody(), modifications);
    modifications.commit();
  }
};
}
/// Creates a new instance of the TOSA shape-inference pass.
std::unique_ptr<Pass> mlir::tosa::createTosaInferShapesPass() {
  std::unique_ptr<Pass> pass = std::make_unique<TosaInferShapes>();
  return pass;
}