//===- Async.cpp - MLIR Async Operations ----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Async/IR/Async.h"

#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::async;

#include "mlir/Dialect/Async/IR/AsyncOpsDialect.cpp.inc"

// Out-of-line definition of the static constexpr member
// `AsyncDialect::kAllowedToBlockAttrName`.
// NOTE(review): with C++17 static constexpr data members are implicitly
// inline, so this definition may be redundant -- confirm before removing.
constexpr StringRef AsyncDialect::kAllowedToBlockAttrName;

/// Registers the operations and types of the dialect. The template argument
/// lists are expanded from the generated `.cpp.inc` files selected by the
/// `GET_OP_LIST` / `GET_TYPEDEF_LIST` macros.
void AsyncDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
      >();
}

//===----------------------------------------------------------------------===//
// ExecuteOp
//===----------------------------------------------------------------------===//

/// Name of the attribute that records the sizes of the two operand groups of
/// `async.execute` (dependencies, then body operands). It is derived by the
/// builder and the parser and is elided when printing.
constexpr char kOperandSegmentSizesAttr[] = "operandSegmentSizes";

/// Returns the body operands as the operands used when entering the body
/// region. The body region is the only valid `point`; anything else trips the
/// assertion.
OperandRange ExecuteOp::getEntrySuccessorOperands(RegionBranchPoint point) {
  assert(point == getBodyRegion() && "invalid region index");
  return getBodyOperands();
}

/// Two types are compatible when they are equal after stripping one level of
/// `!async.value<...>` wrapping; types that are not async values are compared
/// as-is.
bool ExecuteOp::areTypesCompatible(Type lhs, Type rhs) {
  auto stripAsyncValue = [](Type type) -> Type {
    auto asyncValue = llvm::dyn_cast<ValueType>(type);
    return asyncValue ? asyncValue.getValueType() : type;
  };

  Type unwrappedLhs = stripAsyncValue(lhs);
  Type unwrappedRhs = stripAsyncValue(rhs);
  return unwrappedLhs == unwrappedRhs;
}

/// Populates `regions` with the successor of `point`: the body region
/// branches back to the parent operation (producing the body results), and
/// any other branch point enters the body region with its block arguments.
void ExecuteOp::getSuccessorRegions(RegionBranchPoint point,
                                    SmallVectorImpl<RegionSuccessor> &regions) {
  const bool branchingFromBody = point == getBodyRegion();

  if (branchingFromBody) {
    // The `body` region branches back to the parent operation.
    regions.push_back(RegionSuccessor(getBodyResults()));
  } else {
    // Otherwise the successor is the body region.
    Region &body = getBodyRegion();
    regions.push_back(RegionSuccessor(&body, body.getArguments()));
  }
}

/// Builds an `async.execute` operation.
///
/// Operands are added as `dependencies` followed by `operands`, and the
/// matching `operandSegmentSizes` attribute is derived from their counts. The
/// results are a completion token followed by each of `resultTypes` wrapped
/// into `!async.value`. A body block is created with one argument per entry
/// of `operands` (async values are unwrapped, other types are kept as-is).
///
/// If `bodyBuilder` is provided it is invoked to populate the body. Otherwise
/// an empty `async.yield` terminator is created only when `resultTypes` is
/// empty; with results and no builder the body is left unterminated for the
/// caller to complete.
void ExecuteOp::build(OpBuilder &builder, OperationState &result,
                      TypeRange resultTypes, ValueRange dependencies,
                      ValueRange operands, BodyBuilderFn bodyBuilder) {
  // Restores the builder insertion point on exit; the body block creation
  // below moves it.
  OpBuilder::InsertionGuard guard(builder);

  // Operands: dependency tokens first, then the body operands.
  result.addOperands(dependencies);
  result.addOperands(operands);

  // Derive the `operandSegmentSizes` attribute from the two operand groups.
  int32_t dependencyCount = dependencies.size();
  int32_t operandCount = operands.size();
  result.addAttribute(
      kOperandSegmentSizesAttr,
      builder.getDenseI32ArrayAttr({dependencyCount, operandCount}));

  // Results: the completion token, followed by `resultTypes` wrapped into
  // `async.value`.
  result.addTypes({TokenType::get(result.getContext())});
  for (Type wrappedResult : resultTypes)
    result.addTypes(ValueType::get(wrappedResult));

  // The body block takes one argument per body operand. Async values are
  // unwrapped to their underlying type.
  Region *body = result.addRegion();
  Block *entryBlock = builder.createBlock(body);
  for (Value operand : operands) {
    Type argumentType = operand.getType();
    if (auto asyncValue = llvm::dyn_cast<ValueType>(argumentType))
      argumentType = asyncValue.getValueType();
    entryBlock->addArgument(argumentType, operand.getLoc());
  }

  // Let the caller populate the body when a builder is given. Without one, a
  // default terminator can only be created when nothing has to be yielded;
  // otherwise we do not know which values to return from the execute op.
  if (bodyBuilder) {
    bodyBuilder(builder, result.location, entryBlock->getArguments());
  } else if (resultTypes.empty()) {
    builder.create<async::YieldOp>(result.location, ValueRange());
  }
}

/// Prints the custom assembly form of `async.execute`: the optional
/// dependency list, the optional `%value as %unwrapped: type` operand list,
/// the optional result types (without the leading token), the attribute
/// dictionary (without the derived operand segment sizes), and the body
/// region without its entry block arguments.
void ExecuteOp::print(OpAsmPrinter &p) {
  // [%tokens,...]
  auto dependencies = getDependencies();
  if (!dependencies.empty())
    p << " [" << dependencies << "]";

  // (%value as %unwrapped: !async.value<!arg.type>, ...)
  auto bodyOperands = getBodyOperands();
  if (!bodyOperands.empty()) {
    Region &body = getBodyRegion();
    // With an empty region there are no block arguments to pair with the
    // operands; a null value is printed in their place.
    Block *entryBlock = body.empty() ? nullptr : &body.front();
    unsigned nextArgument = 0;

    p << " (";
    llvm::interleaveComma(bodyOperands, p, [&](Value operand) {
      Value unwrapped;
      if (entryBlock)
        unwrapped = entryBlock->getArgument(nextArgument++);
      p << operand << " as " << unwrapped << ": " << operand.getType();
    });
    p << ")";
  }

  // -> (!async.value<!return.type>, ...)
  // The first result is the completion token and is not printed.
  p.printOptionalArrowTypeList(llvm::drop_begin(getResultTypes()));

  // The segment sizes attribute is derived from the operands, so it is elided.
  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(),
                                     {kOperandSegmentSizesAttr});
  p << ' ';
  p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
}

/// Parses the custom assembly form of `async.execute`:
///
///   [%dependency, ...] (%value as %unwrapped : type, ...)
///       -> result-types  attribute-dictionary  region
///
/// Every component except the region is optional. The derived
/// `operandSegmentSizes` attribute is computed from the number of parsed
/// dependencies and body operands, and a completion token type is always
/// prepended to the parsed result types. The `%unwrapped` names become the
/// arguments of the body region.
///
/// Returns failure if any component fails to parse or resolve.
ParseResult ExecuteOp::parse(OpAsmParser &parser, OperationState &result) {
  MLIRContext *ctx = result.getContext();

  // Sizes of parsed variadic operands, will be updated below after parsing.
  int32_t numDependencies = 0;

  auto tokenTy = TokenType::get(ctx);

  // Parse dependency tokens.
  if (succeeded(parser.parseOptionalLSquare())) {
    SmallVector<OpAsmParser::UnresolvedOperand, 4> tokenArgs;
    if (parser.parseOperandList(tokenArgs) ||
        parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
        parser.parseRSquare())
      return failure();

    numDependencies = tokenArgs.size();
  }

  // Parse async value operands (%value as %unwrapped : !async.value<!type>).
  SmallVector<OpAsmParser::UnresolvedOperand, 4> valueArgs;
  SmallVector<OpAsmParser::Argument, 4> unwrappedArgs;
  SmallVector<Type, 4> valueTypes;

  // Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
  auto parseAsyncValueArg = [&]() -> ParseResult {
    if (parser.parseOperand(valueArgs.emplace_back()) ||
        parser.parseKeyword("as") ||
        parser.parseArgument(unwrappedArgs.emplace_back()) ||
        parser.parseColonType(valueTypes.emplace_back()))
      return failure();

    // Async values are unwrapped for the body block argument. Any other
    // operand type is passed through unchanged, which is what
    // `ExecuteOp::build` does, instead of leaving the region argument with a
    // null type.
    auto valueTy = llvm::dyn_cast<ValueType>(valueTypes.back());
    unwrappedArgs.back().type =
        valueTy ? valueTy.getValueType() : valueTypes.back();
    return success();
  };

  auto argsLoc = parser.getCurrentLocation();
  if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::OptionalParen,
                                     parseAsyncValueArg) ||
      parser.resolveOperands(valueArgs, valueTypes, argsLoc, result.operands))
    return failure();

  int32_t numOperands = valueArgs.size();

  // Add derived `operandSegmentSizes` attribute based on parsed operands.
  auto operandSegmentSizes =
      parser.getBuilder().getDenseI32ArrayAttr({numDependencies, numOperands});
  result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);

  // Parse the types of results returned from the async execute op.
  SmallVector<Type, 4> resultTypes;
  NamedAttrList attrs;
  if (parser.parseOptionalArrowTypeList(resultTypes) ||
      // Async execute first result is always a completion token.
      parser.addTypeToList(tokenTy, result.types) ||
      parser.addTypesToList(resultTypes, result.types) ||
      // Parse operation attributes.
      parser.parseOptionalAttrDictWithKeyword(attrs))
    return failure();

  result.addAttributes(attrs);

  // Parse asynchronous region.
  Region *body = result.addRegion();
  return parser.parseRegion(*body, /*arguments=*/unwrappedArgs);
}

/// Verifies that the body region arguments match the body operands: an
/// `!async.value<T>` operand must map to an argument of type `T`, and an
/// operand of any other type must map to an argument of that same type (the
/// convention used by `ExecuteOp::build`).
LogicalResult ExecuteOp::verifyRegions() {
  // Unwrap async.execute value operands types. A checked `dyn_cast` is used
  // so that an operand that is not an async value yields a diagnostic below
  // (or a match) rather than an invalid cast.
  auto unwrappedTypes =
      llvm::map_range(getBodyOperands(), [](Value operand) -> Type {
        Type operandType = operand.getType();
        if (auto valueType = llvm::dyn_cast<ValueType>(operandType))
          return valueType.getValueType();
        return operandType;
      });

  // Verify that unwrapped argument types matches the body region arguments.
  if (getBodyRegion().getArgumentTypes() != unwrappedTypes)
    return emitOpError("async body region argument types do not match the "
                       "execute operation arguments types");

  return success();
}

//===----------------------------------------------------------------------===//
// CreateGroupOp
//===----------------------------------------------------------------------===//

/// Erases a group whose only users are `await_all` operations, together with
/// those users. Nothing was ever added to such a group, so awaiting it is a
/// no-op. Fails (without modifying the IR) as soon as any other kind of user
/// is found.
LogicalResult CreateGroupOp::canonicalize(CreateGroupOp op,
                                          PatternRewriter &rewriter) {
  // Collect the `await_all` users, bailing out on the first user of any
  // other kind.
  llvm::SmallVector<AwaitAllOp> awaitAllUsers;
  for (Operation *user : op->getUsers()) {
    auto awaitAll = dyn_cast<AwaitAllOp>(user);
    if (!awaitAll)
      return failure();
    awaitAllUsers.push_back(awaitAll);
  }

  // The group is only awaited without anything being added to it, so the
  // users and then the create operation itself can be erased.
  for (AwaitAllOp awaitAll : awaitAllUsers)
    rewriter.eraseOp(awaitAll);
  rewriter.eraseOp(op);

  return success();
}

//===----------------------------------------------------------------------===//
// AwaitOp
//===----------------------------------------------------------------------===//

/// Builds an `async.await` on `operand` with the given attributes. A result
/// of the unwrapped type is added only when `operand` is an `!async.value`;
/// for any other operand type the operation has no results.
void AwaitOp::build(OpBuilder &builder, OperationState &result, Value operand,
                    ArrayRef<NamedAttribute> attrs) {
  result.addOperands({operand});
  result.attributes.append(attrs.begin(), attrs.end());

  // Only async values produce a result when awaited.
  auto asyncValue = llvm::dyn_cast<ValueType>(operand.getType());
  if (!asyncValue)
    return;
  result.addTypes(asyncValue.getValueType());
}

/// Parses the operand type of an await and infers the result type from it.
/// `resultType` is set to the unwrapped type when the operand is an
/// `!async.value`, and is left untouched otherwise.
static ParseResult parseAwaitResultType(OpAsmParser &parser, Type &operandType,
                                        Type &resultType) {
  if (failed(parser.parseType(operandType)))
    return failure();

  // Only an async value operand contributes a result type.
  auto asyncValue = llvm::dyn_cast<ValueType>(operandType);
  if (asyncValue)
    resultType = asyncValue.getValueType();

  return success();
}

/// Prints only the operand type; `op` and `resultType` are unused because
/// `parseAwaitResultType` recomputes the result type from the operand type.
static void printAwaitResultType(OpAsmPrinter &p, Operation *op,
                                 Type operandType, Type resultType) {
  p << operandType;
}

/// Verifies the relation between the awaited operand and the result:
///  - awaiting a token must not produce a result;
///  - awaiting an `!async.value<T>` must produce exactly the unwrapped type
///    `T`.
LogicalResult AwaitOp::verify() {
  Type argType = getOperand().getType();

  // Awaiting on a token does not have any results.
  if (llvm::isa<TokenType>(argType) && !getResultTypes().empty())
    return emitOpError("awaiting on a token must have empty result");

  // Awaiting on a value unwraps the async value type.
  if (auto value = llvm::dyn_cast<ValueType>(argType)) {
    // The result may be absent (e.g. an operation created from the generic
    // form); report that instead of dereferencing an empty result type.
    if (getResultTypes().empty())
      return emitOpError("awaiting on a value must have a result of type ")
             << value.getValueType();

    Type resultType = *getResultType();
    if (resultType != value.getValueType())
      return emitOpError() << "result type " << resultType
                           << " does not match async value type "
                           << value.getValueType();
  }

  return success();
}

//===----------------------------------------------------------------------===//
// FuncOp
//===----------------------------------------------------------------------===//

/// Builds an `async.func` named `name` with function type `type`. `attrs` are
/// appended after the symbol name and function type attributes, and an empty
/// body region is added. When `argAttrs` is non-empty it must contain one
/// entry per function input; no result attributes are attached.
void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
                   FunctionType type, ArrayRef<NamedAttribute> attrs,
                   ArrayRef<DictionaryAttr> argAttrs) {
  // Symbol name and function type come first, then the user attributes.
  StringAttr symbolName = builder.getStringAttr(name);
  state.addAttribute(SymbolTable::getSymbolAttrName(), symbolName);
  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
  state.attributes.append(attrs.begin(), attrs.end());

  state.addRegion();

  // Attach argument attributes only when some were provided.
  if (!argAttrs.empty()) {
    assert(type.getNumInputs() == argAttrs.size());
    function_interface_impl::addArgAndResultAttrs(
        builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
        getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
  }
}

/// Parses an `async.func` through the shared function-op parser. Variadic
/// signatures are not allowed, and the function type is a plain
/// `FunctionType` built from the parsed argument and result types.
ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  // Callback used by the shared parser to assemble the function type; the
  // variadic flag and the error string are not needed here.
  auto makeFunctionType =
      [](Builder &b, ArrayRef<Type> inputs, ArrayRef<Type> outputs,
         function_interface_impl::VariadicFlag,
         std::string &) { return b.getFunctionType(inputs, outputs); };

  auto typeAttrName = getFunctionTypeAttrName(result.name);
  auto argAttrsName = getArgAttrsAttrName(result.name);
  auto resAttrsName = getResAttrsAttrName(result.name);
  return function_interface_impl::parseFunctionOp(
      parser, result, /*allowVariadic=*/false, typeAttrName, makeFunctionType,
      argAttrsName, resAttrsName);
}

/// Prints the `async.func` through the shared function-op printer, as a
/// non-variadic function, using this op's function type, argument attribute,
/// and result attribute names.
void FuncOp::print(OpAsmPrinter &p) {
  function_interface_impl::printFunctionOp(
      p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
      getArgAttrsAttrName(), getResAttrsAttrName());
}

/// Verifies the results of an `async.func`: there must be at least one
/// result, every result must be an async token or an async value, and a
/// token may only appear as the first result.
LogicalResult FuncOp::verify() {
  auto resultTypes = getResultTypes();
  if (resultTypes.empty())
    return emitOpError()
           << "result is expected to be at least of size 1, but got "
           << resultTypes.size();

  const unsigned numResults = resultTypes.size();
  for (unsigned index = 0; index < numResults; ++index) {
    Type type = resultTypes[index];
    const bool isToken = llvm::isa<TokenType>(type);

    if (!isToken && !llvm::isa<ValueType>(type))
      return emitOpError() << "result type must be async value type or async "
                              "token type, but got "
                           << type;

    // An async token is only allowed as the first return value.
    if (isToken && index != 0)
      return emitOpError()
             << " results' (optional) async token type is expected "
                "to appear as the 1st return value, but got "
             << index + 1;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//

/// Verifies that the `callee` attribute references an `async.func` and that
/// the operand and result types of the call match the callee's function type
/// exactly, in number and in type.
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // The callee attribute must be present.
  auto calleeAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
  if (!calleeAttr)
    return emitOpError("requires a 'callee' symbol reference attribute");

  // It must resolve to an async function.
  FuncOp callee = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, calleeAttr);
  if (!callee)
    return emitOpError() << "'" << calleeAttr.getValue()
                         << "' does not reference a valid async function";

  auto calleeType = callee.getFunctionType();

  // Operands: count first, then each type in order.
  const unsigned numInputs = calleeType.getNumInputs();
  if (numInputs != getNumOperands())
    return emitOpError("incorrect number of operands for callee");

  for (unsigned index = 0; index < numInputs; ++index) {
    Type expected = calleeType.getInput(index);
    Type provided = getOperand(index).getType();
    if (provided != expected)
      return emitOpError("operand type mismatch: expected operand type ")
             << expected << ", but provided " << provided
             << " for operand number " << index;
  }

  // Results: count first, then each type in order.
  const unsigned numResults = calleeType.getNumResults();
  if (numResults != getNumResults())
    return emitOpError("incorrect number of results for callee");

  for (unsigned index = 0; index < numResults; ++index) {
    if (getResult(index).getType() == calleeType.getResult(index))
      continue;

    // Report the full type lists of both sides to make the mismatch easy to
    // spot.
    auto diag = emitOpError("result type mismatch at index ") << index;
    diag.attachNote() << "      op result types: " << getResultTypes();
    diag.attachNote() << "function result types: " << calleeType.getResults();
    return diag;
  }

  return success();
}

/// Returns the function type implied by this call: its operand types as
/// inputs and its result types as results. This is built from the call
/// itself, not looked up from the referenced function.
FunctionType CallOp::getCalleeType() {
  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
}

//===----------------------------------------------------------------------===//
// ReturnOp
//===----------------------------------------------------------------------===//

/// Verifies that the returned operand types equal the unwrapped value types
/// of the parent `async.func` results. For a stateful function the first
/// result is skipped before comparing.
LogicalResult ReturnOp::verify() {
  auto parentFunc = (*this)->getParentOfType<FuncOp>();

  // Drop the leading result of a stateful function; it is not returned by
  // value.
  ArrayRef<Type> asyncResultTypes = parentFunc.getResultTypes();
  if (parentFunc.isStateful())
    asyncResultTypes = asyncResultTypes.drop_front();

  // Get the underlying value types from the async types returned from the
  // parent `async.func` operation.
  auto unwrappedTypes =
      llvm::map_range(asyncResultTypes, [](const Type &asyncType) {
        return llvm::cast<ValueType>(asyncType).getValueType();
      });

  if (getOperandTypes() != unwrappedTypes)
    return emitOpError("operand types do not match the types returned from "
                       "the parent FuncOp");

  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"

//===----------------------------------------------------------------------===//
// TableGen'd type method definitions
//===----------------------------------------------------------------------===//

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"

/// Prints the type parameters as `<value-type>`.
void ValueType::print(AsmPrinter &printer) const {
  printer << '<';
  printer.printType(getValueType());
  printer << ">";
}

/// Parses `<value-type>` and returns the corresponding async value type. On
/// any parse failure an additional error is emitted at the name location and
/// a null type is returned.
Type ValueType::parse(mlir::AsmParser &parser) {
  Type wrappedType;
  const bool malformed = parser.parseLess() || parser.parseType(wrappedType) ||
                         parser.parseGreater();
  if (!malformed)
    return ValueType::get(wrappedType);

  parser.emitError(parser.getNameLoc(), "failed to parse async value type");
  return Type();
}