#include "Parser.h"
#include "ParserState.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/AsmState.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <cstdint>
#include <utility>
using namespace mlir;
using namespace mlir::detail;
namespace {
/// Binary operators of the lower affine precedence level (`+` and `-`), as
/// returned by AffineParser::consumeIfLowPrecOp().
/// LNoOp must stay the first enumerator (value 0): callers test the result of
/// consumeIfLowPrecOp() directly as an `if` condition.
enum AffineLowPrecOp {
/// The current token is not a low-precedence operator.
LNoOp,
Add,
Sub
};
/// Binary operators of the higher affine precedence level (`*`, `floordiv`,
/// `ceildiv`, `mod`), as returned by AffineParser::consumeIfHighPrecOp().
/// HNoOp must stay the first enumerator (value 0) for the same reason as LNoOp.
enum AffineHighPrecOp {
/// The current token is not a high-precedence operator.
HNoOp,
Mul,
FloorDiv,
CeilDiv,
Mod
};
/// Recursive-descent parser for affine expressions, affine maps and integer
/// sets, layered on top of the generic Parser. It keeps a table of the
/// identifiers (dimension/symbol names or SSA ids) declared so far.
class AffineParser : public Parser {
public:
/// When `allowParsingSSAIds` is true, `%ssa` operands are accepted inside
/// expressions; every SSA id seen for the first time is handed to
/// `parseElement`, whose bool argument says whether it is bound as a symbol.
AffineParser(ParserState &state, bool allowParsingSSAIds = false,
function_ref<ParseResult(bool)> parseElement = nullptr)
: Parser(state), allowParsingSSAIds(allowParsingSSAIds),
parseElement(parseElement) {}
/// Parses the parenthesized result list of an affine map with the given
/// number of dimensions and symbols.
ParseResult parseAffineMapRange(unsigned numDims, unsigned numSymbols,
AffineMap &result);
/// Parses `(dims)[symbols]` followed by either `-> (results)` (assigning
/// `map`) or `: (constraints)` (assigning `set`).
ParseResult parseAffineMapOrIntegerSetInline(AffineMap &map, IntegerSet &set);
/// Parses one affine expression whose identifiers are resolved against the
/// caller-provided `symbolSet` bindings.
ParseResult
parseAffineExprInline(ArrayRef<std::pair<StringRef, AffineExpr>> symbolSet,
AffineExpr &expr);
/// Parses the parenthesized constraint list of an integer set.
ParseResult parseIntegerSetConstraints(unsigned numDims, unsigned numSymbols,
IntegerSet &result);
/// Parses a delimited list of affine expressions over SSA operands.
ParseResult parseAffineMapOfSSAIds(AffineMap &map,
OpAsmParser::Delimiter delimiter);
/// Parses a single affine expression over SSA operands.
ParseResult parseAffineExprOfSSAIds(AffineExpr &expr);
private:
// Operator recognition: consume the current token if it is an operator of
// the given precedence level, otherwise return the NoOp value.
AffineLowPrecOp consumeIfLowPrecOp();
AffineHighPrecOp consumeIfHighPrecOp();
// Identifier-list parsing for the `(dims)[symbols]` map/set prefix.
ParseResult parseDimIdList(unsigned &numDims);
ParseResult parseSymbolIdList(unsigned &numSymbols);
ParseResult parseDimAndOptionalSymbolIdList(unsigned &numDims,
unsigned &numSymbols);
ParseResult parseIdentifierDefinition(AffineExpr idExpr);
// Expression parsing. These return a null AffineExpr after emitting a
// diagnostic on failure.
AffineExpr parseAffineExpr();
AffineExpr parseParentheticalExpr();
AffineExpr parseNegateExpression(AffineExpr lhs);
AffineExpr parseIntegerExpr();
AffineExpr parseBareIdExpr();
AffineExpr parseSSAIdExpr(bool isSymbol);
AffineExpr parseSymbolSSAIdExpr();
AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
AffineExpr rhs, SMLoc opLoc);
AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
AffineExpr rhs);
AffineExpr parseAffineOperandExpr(AffineExpr lhs);
AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
SMLoc llhsOpLoc);
AffineExpr parseAffineConstraint(bool *isEq);
private:
/// If false, parseSSAIdExpr() rejects `%ssa` operands.
bool allowParsingSSAIds;
/// Client callback invoked for each SSA id seen for the first time; its
/// argument is true when the id is being bound as a symbol.
function_ref<ParseResult(bool)> parseElement;
/// Number of distinct SSA ids bound as dimensions / symbols so far.
unsigned numDimOperands = 0;
unsigned numSymbolOperands = 0;
/// (name, expression) bindings of all identifiers declared so far; looked up
/// by linear search.
SmallVector<std::pair<StringRef, AffineExpr>, 4> dimsAndSymbols;
};
} // namespace
/// Combines `lhs` and `rhs` with the high-precedence operator `op`.
///
/// A product is only affine if at least one factor is symbolic or constant;
/// floordiv, ceildiv and mod only constrain their right operand. If the check
/// fails, a diagnostic is emitted at `opLoc` and a null expression returned.
AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
                                               AffineExpr lhs, AffineExpr rhs,
                                               SMLoc opLoc) {
  switch (op) {
  case Mul: {
    bool isAffineProduct =
        lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant();
    if (isAffineProduct)
      return lhs * rhs;
    emitError(opLoc, "non-affine expression: at least one of the multiply "
                     "operands has to be either a constant or symbolic");
    return nullptr;
  }
  case FloorDiv:
    if (rhs.isSymbolicOrConstant())
      return lhs.floorDiv(rhs);
    emitError(opLoc, "non-affine expression: right operand of floordiv "
                     "has to be either a constant or symbolic");
    return nullptr;
  case CeilDiv:
    if (rhs.isSymbolicOrConstant())
      return lhs.ceilDiv(rhs);
    emitError(opLoc, "non-affine expression: right operand of ceildiv "
                     "has to be either a constant or symbolic");
    return nullptr;
  case Mod:
    if (rhs.isSymbolicOrConstant())
      return lhs % rhs;
    emitError(opLoc, "non-affine expression: right operand of mod "
                     "has to be either a constant or symbolic");
    return nullptr;
  case HNoOp:
    // Callers only get here after consumeIfHighPrecOp() found an operator.
    llvm_unreachable("can't create affine expression for null high prec op");
  }
  llvm_unreachable("Unknown AffineHighPrecOp");
}
/// Combines `lhs` and `rhs` with the low-precedence operator `op` (`+` or
/// `-`). Unlike the high-precedence variant no affinity check is performed.
AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
                                               AffineExpr lhs, AffineExpr rhs) {
  if (op == AffineLowPrecOp::Add)
    return lhs + rhs;
  if (op == AffineLowPrecOp::Sub)
    return lhs - rhs;
  if (op == AffineLowPrecOp::LNoOp)
    llvm_unreachable("can't create affine expression for null low prec op");
  llvm_unreachable("Unknown AffineLowPrecOp");
}
/// If the current token is `+` or `-`, consumes it and returns the matching
/// operator; otherwise leaves the token stream untouched and returns LNoOp.
AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
  if (consumeIf(Token::plus))
    return AffineLowPrecOp::Add;
  if (consumeIf(Token::minus))
    return AffineLowPrecOp::Sub;
  return AffineLowPrecOp::LNoOp;
}
/// If the current token is `*`, `floordiv`, `ceildiv` or `mod`, consumes it
/// and returns the matching operator; otherwise consumes nothing and returns
/// HNoOp.
AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
  if (consumeIf(Token::star))
    return Mul;
  if (consumeIf(Token::kw_floordiv))
    return FloorDiv;
  if (consumeIf(Token::kw_ceildiv))
    return CeilDiv;
  if (consumeIf(Token::kw_mod))
    return Mod;
  return HNoOp;
}
/// Parses the remainder of a chain of high-precedence operations (`*`,
/// `floordiv`, `ceildiv`, `mod`), folding them left to right.
///
/// `llhs` is the already-parsed left operand of the pending operator `llhsOp`
/// (null when there is none), and `llhsOpLoc` is where `llhsOp` appeared in
/// the input. Any diagnostic about `llhsOp` is anchored there. Returns a null
/// expression after emitting a diagnostic on error.
AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs,
                                                   AffineHighPrecOp llhsOp,
                                                   SMLoc llhsOpLoc) {
  AffineExpr lhs = parseAffineOperandExpr(llhs);
  if (!lhs)
    return nullptr;

  // Found an operand; check whether the chain continues.
  SMLoc opLoc = getToken().getLoc();
  if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
    if (llhs) {
      // Left associativity: fold `llhs llhsOp lhs` before continuing. The
      // check here concerns `llhsOp`, so report it at that operator rather
      // than at the operator that follows it.
      AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
      if (!expr)
        return nullptr;
      return parseAffineHighPrecOpExpr(expr, op, opLoc);
    }
    // No pending operator: `lhs` becomes the left operand of `op`.
    return parseAffineHighPrecOpExpr(lhs, op, opLoc);
  }

  // `lhs` is the last operand of the chain.
  if (llhs)
    return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
  return lhs;
}
/// Parses `(` affine-expr `)`. An empty pair of parentheses gets a dedicated
/// diagnostic. Returns a null expression on error.
AffineExpr AffineParser::parseParentheticalExpr() {
  if (parseToken(Token::l_paren, "expected '('"))
    return nullptr;
  if (getToken().is(Token::r_paren)) {
    emitError("no expression inside parentheses");
    return nullptr;
  }
  AffineExpr inner = parseAffineExpr();
  if (!inner)
    return nullptr;
  if (parseToken(Token::r_paren, "expected ')'"))
    return nullptr;
  return inner;
}
/// Parses unary negation `- operand` and returns `(-1) * operand`. `lhs` is
/// forwarded to the operand parser, which uses it to pick its diagnostics.
AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
  if (parseToken(Token::minus, "expected '-'"))
    return nullptr;
  if (AffineExpr operand = parseAffineOperandExpr(lhs))
    return (-1) * operand;
  emitError("missing operand of negation");
  return nullptr;
}
/// Returns true if `token` may name an affine identifier: any keyword, a bare
/// identifier, or a token lexed as `Token::inttype`.
static bool isIdentifier(const Token &token) {
  if (token.isKeyword())
    return true;
  return token.isAny(Token::bare_identifier, Token::inttype);
}
/// Parses a use of a previously declared dimension or symbol name and returns
/// the expression bound to it. Returns a null expression on error.
AffineExpr AffineParser::parseBareIdExpr() {
  if (!isIdentifier(getToken())) {
    emitWrongTokenError("expected bare identifier");
    return nullptr;
  }

  StringRef spelling = getTokenSpelling();
  for (const auto &[name, boundExpr] : dimsAndSymbols) {
    if (name != spelling)
      continue;
    consumeToken();
    return boundExpr;
  }

  emitWrongTokenError("use of undeclared identifier");
  return nullptr;
}
/// Parses an `%ssa` operand inside an affine expression.
///
/// An SSA id seen before resolves to the same dimension/symbol expression.
/// A new one is first passed to the client `parseElement` callback. It is
/// then bound to the next free symbol position (if `isSymbol`) or dimension
/// position. Returns a null expression on error.
AffineExpr AffineParser::parseSSAIdExpr(bool isSymbol) {
  if (!allowParsingSSAIds) {
    emitWrongTokenError("unexpected ssa identifier");
    return nullptr;
  }
  if (getToken().isNot(Token::percent_identifier)) {
    emitWrongTokenError("expected ssa identifier");
    return nullptr;
  }

  StringRef ssaName = getTokenSpelling();
  for (const auto &[knownName, knownExpr] : dimsAndSymbols) {
    if (knownName == ssaName) {
      consumeToken(Token::percent_identifier);
      return knownExpr;
    }
  }

  // First occurrence: let the client handle the operand, then give it a
  // fresh position.
  if (parseElement(isSymbol))
    return nullptr;
  AffineExpr fresh;
  if (isSymbol)
    fresh = getAffineSymbolExpr(numSymbolOperands++, getContext());
  else
    fresh = getAffineDimExpr(numDimOperands++, getContext());
  dimsAndSymbols.emplace_back(ssaName, fresh);
  return fresh;
}
/// Parses `symbol` `(` ssa-id `)`, binding the SSA id as a symbol. Returns a
/// null expression on error.
AffineExpr AffineParser::parseSymbolSSAIdExpr() {
  if (parseToken(Token::kw_symbol, "expected symbol keyword"))
    return nullptr;
  if (parseToken(Token::l_paren, "expected '(' at start of SSA symbol"))
    return nullptr;

  AffineExpr symbolExpr = parseSSAIdExpr(/*isSymbol=*/true);
  if (!symbolExpr ||
      parseToken(Token::r_paren, "expected ')' at end of SSA symbol"))
    return nullptr;
  return symbolExpr;
}
/// Parses an integer literal as an affine constant. Literals that do not fit
/// in a signed 64-bit value are rejected.
AffineExpr AffineParser::parseIntegerExpr() {
  auto literal = getToken().getUInt64IntegerValue();
  // Values above INT64_MAX turn negative when reinterpreted as int64_t.
  bool fitsInIndex =
      literal.has_value() && static_cast<int64_t>(*literal) >= 0;
  if (!fitsInIndex) {
    emitError("constant too large for index");
    return nullptr;
  }
  consumeToken(Token::integer);
  return builder.getAffineConstantExpr(static_cast<int64_t>(*literal));
}
/// Parses a single operand of an affine expression: an SSA id, `symbol(...)`,
/// an integer, a parenthesized expression, a negation, or a bare identifier.
/// `lhs` is the pending left operand (or null) and only selects which
/// diagnostic is emitted when no operand is found.
AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
  switch (getToken().getKind()) {
  case Token::kw_symbol:
    return parseSymbolSSAIdExpr();
  case Token::percent_identifier:
    return parseSSAIdExpr(/*isSymbol=*/false);
  case Token::integer:
    return parseIntegerExpr();
  case Token::l_paren:
    return parseParentheticalExpr();
  case Token::minus:
    return parseNegateExpression(lhs);
  case Token::kw_ceildiv:
  case Token::kw_floordiv:
  case Token::kw_mod:
    // The operator keywords may also be used as identifier names.
    return parseBareIdExpr();
  case Token::plus:
  case Token::star:
    // A binary operator cannot begin an operand.
    emitError(lhs ? "missing right operand of binary operator"
                  : "missing left operand of binary operator");
    return nullptr;
  default:
    break;
  }

  // Anything else is accepted only if it can be read as an identifier.
  if (isIdentifier(getToken()))
    return parseBareIdExpr();
  emitError(lhs ? "missing right operand of binary operator"
                : "expected affine expression");
  return nullptr;
}
/// Parses a sequence of additive terms, folding them left to right.
///
/// `llhs` is the expression accumulated to the left of the pending operator
/// `llhsOp`; both are null/LNoOp at the start of an expression. Each call
/// parses one operand. If a high-precedence operator follows that operand, the
/// whole high-precedence chain is parsed first, so `a + b * c` groups as
/// `a + (b * c)`. Returns a null expression after emitting a diagnostic on
/// error.
AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
AffineLowPrecOp llhsOp) {
AffineExpr lhs;
if (!(lhs = parseAffineOperandExpr(llhs)))
return nullptr;
// Another `+`/`-` follows: fold `llhs llhsOp lhs` (if there is a pending
// operator) and continue with the rest of the sum.
if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
if (llhs) {
AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
return parseAffineLowPrecOpExpr(sum, lOp);
}
return parseAffineLowPrecOpExpr(lhs, lOp);
}
// Remember where the high-precedence operator (if any) starts so that
// diagnostics about it can point there.
auto opLoc = getToken().getLoc();
if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
// `lhs` is the left operand of a high-precedence chain; parse that chain
// to completion before applying the pending low-precedence operator.
AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
if (!highRes)
return nullptr;
AffineExpr expr =
llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;
// Continue with any `+`/`-` that follows the high-precedence chain.
if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
return parseAffineLowPrecOpExpr(expr, nextOp);
return expr;
}
// `lhs` is the final operand of the expression.
if (llhs)
return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
return lhs;
}
/// Parses a complete affine expression. Parsing starts with no pending left
/// operand and no pending operator.
AffineExpr AffineParser::parseAffineExpr() {
  AffineExpr noPendingLhs;
  return parseAffineLowPrecOpExpr(noPendingLhs, LNoOp);
}
/// Declares the identifier at the current token and binds it to `idExpr`.
/// Fails on tokens that cannot be identifiers and on names already declared.
ParseResult AffineParser::parseIdentifierDefinition(AffineExpr idExpr) {
  if (!isIdentifier(getToken()))
    return emitWrongTokenError("expected bare identifier");

  StringRef newName = getTokenSpelling();
  bool isRedefinition = false;
  for (const auto &binding : dimsAndSymbols) {
    if (binding.first == newName) {
      isRedefinition = true;
      break;
    }
  }
  if (isRedefinition)
    return emitError("redefinition of identifier '" + newName + "'");

  consumeToken();
  dimsAndSymbols.emplace_back(newName, idExpr);
  return success();
}
/// Parses the parenthesized dimension identifier list `(d0, d1, ...)`. Each
/// name is bound to the next dimension position, and `numDims` is incremented
/// per element.
ParseResult AffineParser::parseDimIdList(unsigned &numDims) {
  return parseCommaSeparatedList(
      Delimiter::Paren,
      [&]() -> ParseResult {
        AffineExpr dim = getAffineDimExpr(numDims++, getContext());
        return parseIdentifierDefinition(dim);
      },
      " in dimensional identifier list");
}
/// Parses the square-bracketed symbol identifier list `[s0, s1, ...]`. Each
/// name is bound to the next symbol position, and `numSymbols` is incremented
/// per element.
ParseResult AffineParser::parseSymbolIdList(unsigned &numSymbols) {
  return parseCommaSeparatedList(
      Delimiter::Square,
      [&]() -> ParseResult {
        AffineExpr sym = getAffineSymbolExpr(numSymbols++, getContext());
        return parseIdentifierDefinition(sym);
      },
      " in symbol list");
}
/// Parses `(dims)` optionally followed by `[symbols]`. When no symbol list is
/// present, `numSymbols` is set to zero.
ParseResult
AffineParser::parseDimAndOptionalSymbolIdList(unsigned &numDims,
                                              unsigned &numSymbols) {
  if (parseDimIdList(numDims))
    return failure();
  if (getToken().is(Token::l_square))
    return parseSymbolIdList(numSymbols);
  numSymbols = 0;
  return success();
}
/// Parses either an affine map `(dims)[syms] -> (results)` into `map` or an
/// integer set `(dims)[syms] : (constraints)` into `set`. Which form applies
/// is decided by the token after the identifier lists.
ParseResult AffineParser::parseAffineMapOrIntegerSetInline(AffineMap &map,
                                                           IntegerSet &set) {
  unsigned numDims = 0;
  unsigned numSymbols = 0;
  if (parseDimAndOptionalSymbolIdList(numDims, numSymbols))
    return failure();

  if (!consumeIf(Token::arrow)) {
    // Not a map, so this must be an integer set introduced by `:`.
    if (parseToken(Token::colon, "expected '->' or ':'"))
      return failure();
    return parseIntegerSetConstraints(numDims, numSymbols, set);
  }
  return parseAffineMapRange(numDims, numSymbols, map);
}
/// Parses one affine expression, resolving identifiers against `symbolSet`.
/// The identifier table is replaced by these bindings before parsing.
ParseResult AffineParser::parseAffineExprInline(
    ArrayRef<std::pair<StringRef, AffineExpr>> symbolSet, AffineExpr &expr) {
  dimsAndSymbols.assign(symbolSet.begin(), symbolSet.end());
  expr = parseAffineExpr();
  return expr ? success() : failure();
}
/// Parses a `delimiter`-enclosed, comma-separated (possibly empty) list of
/// affine expressions over SSA operands and builds the corresponding map.
ParseResult
AffineParser::parseAffineMapOfSSAIds(AffineMap &map,
                                     OpAsmParser::Delimiter delimiter) {
  SmallVector<AffineExpr, 4> results;
  auto parseResultExpr = [&]() -> ParseResult {
    AffineExpr parsed = parseAffineExpr();
    results.push_back(parsed);
    return parsed ? success() : failure();
  };

  if (parseCommaSeparatedList(delimiter, parseResultExpr, " in affine map"))
    return failure();

  // Bindings not counted as dimension operands are taken to be symbols.
  unsigned numSymbols = dimsAndSymbols.size() - numDimOperands;
  map = AffineMap::get(numDimOperands, numSymbols, results, getContext());
  return success();
}
/// Parses a single affine expression whose operands are SSA ids.
ParseResult AffineParser::parseAffineExprOfSSAIds(AffineExpr &expr) {
  expr = parseAffineExpr();
  return expr ? success() : failure();
}
/// Parses the parenthesized result list `(e0, e1, ...)` of an affine map and
/// builds the map with `numDims` dimensions and `numSymbols` symbols.
ParseResult AffineParser::parseAffineMapRange(unsigned numDims,
                                              unsigned numSymbols,
                                              AffineMap &result) {
  SmallVector<AffineExpr, 4> rangeExprs;
  auto parseRangeExpr = [&]() -> ParseResult {
    AffineExpr parsed = parseAffineExpr();
    rangeExprs.push_back(parsed);
    return parsed ? success() : failure();
  };

  if (parseCommaSeparatedList(Delimiter::Paren, parseRangeExpr,
                              " in affine map range"))
    return failure();

  result = AffineMap::get(numDims, numSymbols, rangeExprs, getContext());
  return success();
}
/// Parses one integer-set constraint and normalizes it to a single expression
/// `e`, meaning `e >= 0` when `*isEq` is false or `e == 0` when it is true:
///
///   affine-constraint ::= affine-expr `>=` affine-expr   // lhs - rhs >= 0
///                       | affine-expr `<=` affine-expr   // rhs - lhs >= 0
///                       | affine-expr `==` affine-expr   // lhs - rhs == 0
///
/// Returns a null expression after emitting a diagnostic on failure; `*isEq`
/// is only written on success.
AffineExpr AffineParser::parseAffineConstraint(bool *isEq) {
  AffineExpr lhsExpr = parseAffineExpr();
  if (!lhsExpr)
    return nullptr;

  // Each comparison operator is consumed as two tokens: `>`, `<` or `=`,
  // then `=`. Commit to an operator from its first token and then require
  // the `=`. Once a first token has been consumed, no other alternative may
  // be tried. Otherwise malformed input such as `d0 > <= 0` would be accepted
  // as `d0 <= 0`.
  Token::Kind relKind = getToken().getKind();
  bool haveRelOp = consumeIf(Token::greater) || consumeIf(Token::less) ||
                   consumeIf(Token::equal);
  if (!haveRelOp || !consumeIf(Token::equal))
    return emitError("expected '== affine-expr' or '>= affine-expr' at end of "
                     "affine constraint"),
           nullptr;

  AffineExpr rhsExpr = parseAffineExpr();
  if (!rhsExpr)
    return nullptr;

  *isEq = relKind == Token::equal;
  // `lhs <= rhs` is stored as `rhs - lhs >= 0`; the other forms as lhs - rhs.
  return relKind == Token::less ? rhsExpr - lhsExpr : lhsExpr - rhsExpr;
}
/// Parses the parenthesized constraint list of an integer set. An empty list
/// produces the trivially true set with the single equality `0 == 0`.
ParseResult AffineParser::parseIntegerSetConstraints(unsigned numDims,
                                                     unsigned numSymbols,
                                                     IntegerSet &result) {
  SmallVector<AffineExpr, 4> constraints;
  SmallVector<bool, 4> isEqs;
  auto parseConstraint = [&]() -> ParseResult {
    bool isEquality;
    AffineExpr constraint = parseAffineConstraint(&isEquality);
    if (!constraint)
      return failure();
    constraints.push_back(constraint);
    isEqs.push_back(isEquality);
    return success();
  };

  if (parseCommaSeparatedList(Delimiter::Paren, parseConstraint,
                              " in integer set constraint list"))
    return failure();

  if (!constraints.empty()) {
    result = IntegerSet::get(numDims, numSymbols, constraints, isEqs);
    return success();
  }
  AffineExpr zero = getAffineConstantExpr(0, getContext());
  result = IntegerSet::get(numDims, numSymbols, zero, /*eqFlags=*/true);
  return success();
}
/// Parses an inline affine map or integer set at the current position,
/// assigning whichever of `map` / `set` matches the input.
ParseResult Parser::parseAffineMapOrIntegerSetReference(AffineMap &map,
                                                        IntegerSet &set) {
  AffineParser affineParser(state);
  return affineParser.parseAffineMapOrIntegerSetInline(map, set);
}
/// Parses an inline affine map. An integer set is diagnosed at the position
/// where parsing started.
ParseResult Parser::parseAffineMapReference(AffineMap &map) {
  SMLoc startLoc = getToken().getLoc();
  IntegerSet parsedSet;
  if (parseAffineMapOrIntegerSetReference(map, parsedSet))
    return failure();
  if (!parsedSet)
    return success();
  return emitError(startLoc, "expected AffineMap, but got IntegerSet");
}
/// Parses an affine expression whose identifiers are resolved against
/// `symbolSet`.
ParseResult Parser::parseAffineExprReference(
    ArrayRef<std::pair<StringRef, AffineExpr>> symbolSet, AffineExpr &expr) {
  AffineParser exprParser(state);
  return exprParser.parseAffineExprInline(symbolSet, expr);
}
/// Parses an inline integer set. An affine map is diagnosed at the position
/// where parsing started.
ParseResult Parser::parseIntegerSetReference(IntegerSet &set) {
  SMLoc startLoc = getToken().getLoc();
  AffineMap parsedMap;
  if (parseAffineMapOrIntegerSetReference(parsedMap, set))
    return failure();
  if (!parsedMap)
    return success();
  return emitError(startLoc, "expected IntegerSet, but got AffineMap");
}
/// Parses a delimited list of affine expressions over SSA operands. Every new
/// SSA id is handed to `parseElement`.
ParseResult
Parser::parseAffineMapOfSSAIds(AffineMap &map,
                               function_ref<ParseResult(bool)> parseElement,
                               OpAsmParser::Delimiter delimiter) {
  AffineParser ssaParser(state, /*allowParsingSSAIds=*/true, parseElement);
  return ssaParser.parseAffineMapOfSSAIds(map, delimiter);
}
/// Parses one affine expression over SSA operands. Every new SSA id is handed
/// to `parseElement`.
ParseResult
Parser::parseAffineExprOfSSAIds(AffineExpr &expr,
                                function_ref<ParseResult(bool)> parseElement) {
  AffineParser ssaParser(state, /*allowParsingSSAIds=*/true, parseElement);
  return ssaParser.parseAffineExprOfSSAIds(expr);
}
/// Parses `inputStr` as a standalone affine map or integer set and reports
/// diagnostics to llvm::errs().
///
/// On success, `map` or `set` is assigned, depending on the input. On any
/// error, both are reset to null so callers can detect the failure. This
/// includes extra tokens after an otherwise valid map/set.
static void parseAffineMapOrIntegerSet(StringRef inputStr, MLIRContext *context,
                                       AffineMap &map, IntegerSet &set) {
  llvm::SourceMgr sourceMgr;
  auto memBuffer = llvm::MemoryBuffer::getMemBuffer(
      inputStr, /*BufferName=*/"<mlir_parser_buffer>",
      /*RequiresNullTerminator=*/false);
  sourceMgr.AddNewSourceBuffer(std::move(memBuffer), SMLoc());
  SymbolState symbolState;
  ParserConfig config(context);
  ParserState state(sourceMgr, config, symbolState, /*asmState=*/nullptr,
                    /*codeCompleteContext=*/nullptr);
  Parser parser(state);

  // Print diagnostics emitted while parsing, with locations in `inputStr`.
  SourceMgrDiagnosticHandler handler(sourceMgr, context, llvm::errs());

  bool failedToParse = static_cast<bool>(
      parser.parseAffineMapOrIntegerSetReference(map, set));
  if (!failedToParse) {
    Token endTok = parser.getToken();
    if (endTok.isNot(Token::eof)) {
      parser.emitError(endTok.getLoc(), "encountered unexpected token");
      failedToParse = true;
    }
  }

  // Never hand back a partial result: only the whole input counts as success.
  if (failedToParse) {
    map = AffineMap();
    set = IntegerSet();
  }
}
/// Parses `inputStr` as an affine map in `context`. Asserts that the input
/// did not describe an integer set.
AffineMap mlir::parseAffineMap(StringRef inputStr, MLIRContext *context) {
  IntegerSet unexpectedSet;
  AffineMap result;
  parseAffineMapOrIntegerSet(inputStr, context, result, unexpectedSet);
  assert(!unexpectedSet &&
         "expected string to represent AffineMap, but got IntegerSet instead");
  return result;
}
/// Parses `inputStr` as an integer set in `context`. Asserts that the input
/// did not describe an affine map.
IntegerSet mlir::parseIntegerSet(StringRef inputStr, MLIRContext *context) {
  AffineMap unexpectedMap;
  IntegerSet result;
  parseAffineMapOrIntegerSet(inputStr, context, unexpectedMap, result);
  assert(!unexpectedMap &&
         "expected string to represent IntegerSet, but got AffineMap instead");
  return result;
}