#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
using namespace mlir;
using namespace tosa;
namespace {
class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
public:
using OpRewritePattern<tosa::VariableOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::VariableOp op,
PatternRewriter &rewriter) const final {
auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
op.getLoc(), op.getName(), op.getType(), true,
op.getInitialValueAttr(), nullptr);
newVariable.setPrivate();
rewriter.replaceOp(op, newVariable);
return success();
}
};
class VariableWriteOpConverter
: public OpRewritePattern<tosa::VariableWriteOp> {
public:
using OpRewritePattern<tosa::VariableWriteOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::VariableWriteOp op,
PatternRewriter &rewriter) const final {
auto globalSymbolRef =
SymbolRefAttr::get(rewriter.getContext(), op.getName());
auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
op.getLoc(), globalSymbolRef, op.getValue());
rewriter.replaceOp(op, newVariableWrite);
return success();
}
};
class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> {
public:
using OpRewritePattern<tosa::VariableReadOp>::OpRewritePattern;
LogicalResult matchAndRewrite(tosa::VariableReadOp op,
PatternRewriter &rewriter) const final {
auto globalSymbolRef =
SymbolRefAttr::get(rewriter.getContext(), op.getName());
auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>(
op.getLoc(), op.getType(), globalSymbolRef);
rewriter.replaceOp(op, newVariableRead);
return success();
}
};
}
void mlir::tosa::populateTosaToMLProgramConversionPatterns(
RewritePatternSet *patterns) {
patterns->add<VariableOpConverter, VariableWriteOpConverter,
VariableReadOpConverter>(patterns->getContext());
}