//===- MemoryOps.cpp - MLIR SPIR-V Memory Ops  ----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the memory operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"

#include "SPIRVOpUtils.h"
#include "SPIRVParsingUtils.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/Diagnostics.h"

#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Casting.h"

using namespace mlir::spirv::AttrNames;

namespace mlir::spirv {

/// Parses the optional memory access (a.k.a. memory operand) attributes
/// attached to a memory access operand/pointer. The accepted syntax is:
///     (`[` memory-access `]`)?
/// where:
///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
///         integer-literal | `"NonTemporal"`
/// The attribute names used are the ones `MemoryOpTy` reports for its memory
/// access and alignment attributes. Returns failure on malformed input.
template <typename MemoryOpTy>
ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
                                        OperationState &state) {
  // The bracketed list is optional: without a leading '[' there is nothing
  // to parse.
  if (failed(parser.parseOptionalLSquare()))
    return success();

  spirv::MemoryAccess access;
  StringAttr accessAttrName = MemoryOpTy::getMemoryAccessAttrName(state.name);
  if (failed(spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
          access, parser, state, accessAttrName)))
    return failure();

  // Only an `Aligned` access is followed by `, <alignment>`.
  if (!spirv::bitEnumContainsAll(access, spirv::MemoryAccess::Aligned))
    return parser.parseRSquare();

  // Parse the alignment as an i32 attribute stored in `state.attributes`.
  Attribute alignment;
  StringAttr alignmentName = MemoryOpTy::getAlignmentAttrName(state.name);
  Type i32Type = parser.getBuilder().getIntegerType(32);
  if (failed(parser.parseComma()) ||
      failed(parser.parseAttribute(alignment, i32Type, alignmentName,
                                   state.attributes)))
    return failure();

  return parser.parseRSquare();
}

// TODO: Merge this and the previous function into one template parameterized
// by memory access attribute name and alignment. Doing so currently causes
// VS2017 to produce an internal error (at the call site) that is not detailed
// enough to understand what is happening.
/// Parses an optional `[` memory-access `]` list like
/// `parseMemoryAccessAttributes`, but uses the *source* memory access and
/// *source* alignment attribute names reported by `MemoryOpTy`.
template <typename MemoryOpTy>
static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
                                                     OperationState &state) {
  // The bracketed list is optional: without a leading '[' there is nothing
  // to parse.
  if (failed(parser.parseOptionalLSquare()))
    return success();

  spirv::MemoryAccess access;
  StringRef accessAttrName =
      MemoryOpTy::getSourceMemoryAccessAttrName(state.name);
  if (failed(spirv::parseEnumStrAttr<spirv::MemoryAccessAttr>(
          access, parser, state, accessAttrName)))
    return failure();

  // Only an `Aligned` access is followed by `, <alignment>`.
  if (!spirv::bitEnumContainsAll(access, spirv::MemoryAccess::Aligned))
    return parser.parseRSquare();

  // Parse the alignment as an i32 attribute stored in `state.attributes`.
  Attribute alignment;
  StringAttr alignmentName =
      MemoryOpTy::getSourceAlignmentAttrName(state.name);
  Type i32Type = parser.getBuilder().getIntegerType(32);
  if (failed(parser.parseComma()) ||
      failed(parser.parseAttribute(alignment, i32Type, alignmentName,
                                   state.attributes)))
    return failure();

  return parser.parseRSquare();
}

// TODO: Merge this and `printMemoryAccessAttribute` (defined below) into one
// template parameterized by memory access attribute name and alignment. Doing
// so currently causes VS2017 to produce an internal error (at the call site)
// that is not detailed enough to understand what is happening.
template <typename MemoryOpTy>
static void printSourceMemoryAccessAttribute(
    MemoryOpTy memoryOp, OpAsmPrinter &printer,
    SmallVectorImpl<StringRef> &elidedAttrs,
    std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
    std::optional<uint32_t> alignmentAttrValue = std::nullopt) {

  printer << ", ";

  // Print optional memory access attribute.
  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
                                              : memoryOp.getMemoryAccess())) {
    elidedAttrs.push_back(memoryOp.getSourceMemoryAccessAttrName());

    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";

    if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
      // Print integer alignment attribute.
      if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                               : memoryOp.getAlignment())) {
        elidedAttrs.push_back(memoryOp.getSourceAlignmentAttrName());
        printer << ", " << *alignment;
      }
    }
    printer << "]";
  }
  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
}

template <typename MemoryOpTy>
static void printMemoryAccessAttribute(
    MemoryOpTy memoryOp, OpAsmPrinter &printer,
    SmallVectorImpl<StringRef> &elidedAttrs,
    std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
    std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
  // Print optional memory access attribute.
  if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
                                              : memoryOp.getMemoryAccess())) {
    elidedAttrs.push_back(memoryOp.getMemoryAccessAttrName());

    printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";

    if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
      // Print integer alignment attribute.
      if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
                                               : memoryOp.getAlignment())) {
        elidedAttrs.push_back(memoryOp.getAlignmentAttrName());
        printer << ", " << *alignment;
      }
    }
    printer << "]";
  }
  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
}

/// Checks that the type of `val` equals the pointee type of `ptr`, emitting
/// an op error on `op` when they differ.
template <typename LoadStoreOpTy>
static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
                                                   Value val) {
  // ODS already guarantees that `ptr` is a spirv::PointerType, so only the
  // pointee/value type equality remains to be checked.
  //
  // TODO: Check that the value type satisfies restrictions of
  // SPIR-V OpLoad/OpStore operations
  auto ptrType = llvm::cast<spirv::PointerType>(ptr.getType());
  if (ptrType.getPointeeType() == val.getType())
    return success();

  return op.emitOpError("mismatch in result type and pointer type");
}

/// Verifies the consistency of the memory access and alignment attributes of
/// `memoryOp`:
///  * without a memory access attribute, no alignment may be present;
///  * an `Aligned` memory access requires an alignment;
///  * a non-`Aligned` memory access forbids an alignment.
/// Emits an op error and returns failure when a rule is violated.
template <typename MemoryOpTy>
static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
  Operation *op = memoryOp.getOperation();
  Attribute memAccessAttr = op->getAttr(memoryOp.getMemoryAccessAttrName());
  const bool hasAlignment =
      static_cast<bool>(op->getAttr(memoryOp.getAlignmentAttrName()));

  if (!memAccessAttr) {
    // Alignment attribute shouldn't be present if memory access attribute is
    // not present.
    if (hasAlignment) {
      return memoryOp.emitOpError(
          "invalid alignment specification without aligned memory access "
          "specification");
    }
    return success();
  }

  // Use dyn_cast (not cast) so an attribute of the wrong kind is reported
  // through the diagnostic below instead of tripping the cast assertion.
  auto memAccess = llvm::dyn_cast<spirv::MemoryAccessAttr>(memAccessAttr);
  if (!memAccess) {
    return memoryOp.emitOpError("invalid memory access specifier: ")
           << memAccessAttr;
  }

  const bool isAligned = spirv::bitEnumContainsAll(
      memAccess.getValue(), spirv::MemoryAccess::Aligned);
  if (isAligned && !hasAlignment)
    return memoryOp.emitOpError("missing alignment value");

  if (!isAligned && hasAlignment) {
    return memoryOp.emitOpError(
        "invalid alignment specification with non-aligned memory access "
        "specification");
  }
  return success();
}

// TODO: Merge this and the previous function into one template parameterized
// by memory access attribute name and alignment. Doing so currently causes
// VS2017 to produce an internal error (at the call site) that is not detailed
// enough to understand what is happening.
/// Verifies the consistency of the *source* memory access and *source*
/// alignment attributes of `memoryOp`:
///  * without a source memory access attribute, no source alignment may be
///    present;
///  * an `Aligned` source memory access requires a source alignment;
///  * a non-`Aligned` source memory access forbids a source alignment.
/// Emits an op error and returns failure when a rule is violated.
template <typename MemoryOpTy>
static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
  Operation *op = memoryOp.getOperation();
  Attribute memAccessAttr =
      op->getAttr(memoryOp.getSourceMemoryAccessAttrName());
  const bool hasAlignment =
      static_cast<bool>(op->getAttr(memoryOp.getSourceAlignmentAttrName()));

  if (!memAccessAttr) {
    // Alignment attribute shouldn't be present if memory access attribute is
    // not present.
    if (hasAlignment) {
      return memoryOp.emitOpError(
          "invalid alignment specification without aligned memory access "
          "specification");
    }
    return success();
  }

  // Use dyn_cast (not cast) so an attribute of the wrong kind is reported
  // through the diagnostic below instead of tripping the cast assertion.
  auto memAccess = llvm::dyn_cast<spirv::MemoryAccessAttr>(memAccessAttr);
  if (!memAccess) {
    // Stream the original attribute: `memAccess` is null on this path.
    return memoryOp.emitOpError("invalid memory access specifier: ")
           << memAccessAttr;
  }

  const bool isAligned = spirv::bitEnumContainsAll(
      memAccess.getValue(), spirv::MemoryAccess::Aligned);
  if (isAligned && !hasAlignment)
    return memoryOp.emitOpError("missing alignment value");

  if (!isAligned && hasAlignment) {
    return memoryOp.emitOpError(
        "invalid alignment specification with non-aligned memory access "
        "specification");
  }
  return success();
}

//===----------------------------------------------------------------------===//
// spirv.AccessChainOp
//===----------------------------------------------------------------------===//

/// Computes the pointer type obtained by walking `indices` into the pointee
/// of `type`.
///
/// `type` must be a spirv::PointerType and every level that is indexed must
/// be a spirv::CompositeType. For struct levels the index must be produced by
/// an op from which `spirv::extractValueFromConstOp` can extract a value, and
/// that value must be within the struct's element count. On any violation an
/// error is emitted at `baseLoc` and a null Type is returned. The result
/// keeps the storage class of the base pointer.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
  auto basePtrType = llvm::dyn_cast<spirv::PointerType>(type);
  if (!basePtrType) {
    emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
                       "to composite type, but provided ")
        << type;
    return nullptr;
  }

  Type currentType = basePtrType.getPointeeType();
  int32_t memberIndex = 0;

  for (Value indexValue : indices) {
    auto compositeType = llvm::dyn_cast<spirv::CompositeType>(currentType);
    if (!compositeType) {
      // `memberIndex` still holds the value used for the previous level.
      emitError(
          baseLoc,
          "'spirv.AccessChain' op cannot extract from non-composite type ")
          << currentType << " with index " << memberIndex;
      return nullptr;
    }

    // Non-struct levels are always stepped into with index 0.
    memberIndex = 0;
    if (llvm::isa<spirv::StructType>(currentType)) {
      Operation *definingOp = indexValue.getDefiningOp();
      if (!definingOp) {
        emitError(baseLoc, "'spirv.AccessChain' op index must be an "
                           "integer spirv.Constant to access "
                           "element of spirv.struct");
        return nullptr;
      }

      // TODO: this should be relaxed to allow
      // integer literals of other bitwidths.
      if (failed(spirv::extractValueFromConstOp(definingOp, memberIndex))) {
        emitError(
            baseLoc,
            "'spirv.AccessChain' index must be an integer spirv.Constant to "
            "access element of spirv.struct, but provided ")
            << definingOp->getName();
        return nullptr;
      }

      const bool inBounds =
          memberIndex >= 0 && static_cast<uint64_t>(memberIndex) <
                                  compositeType.getNumElements();
      if (!inBounds) {
        emitError(baseLoc, "'spirv.AccessChain' op index ")
            << memberIndex << " out of bounds for " << currentType;
        return nullptr;
      }
    }

    currentType = compositeType.getElementType(memberIndex);
  }

  return spirv::PointerType::get(currentType, basePtrType.getStorageClass());
}

/// Builds an AccessChainOp whose result type is deduced from `basePtr`'s type
/// and `indices` via `getElementPtrType`.
void AccessChainOp::build(OpBuilder &builder, OperationState &state,
                          Value basePtr, ValueRange indices) {
  Type resultPtrType =
      getElementPtrType(basePtr.getType(), indices, state.location);
  assert(resultPtrType &&
         "Unable to deduce return type based on basePtr and indices");
  build(builder, state, resultPtrType, basePtr, indices);
}

/// Parses `base-ptr [ indices ] : base-ptr-type , index-types` and adds the
/// result type computed by `getElementPtrType` to `result`.
ParseResult AccessChainOp::parse(OpAsmParser &parser, OperationState &result) {
  auto operandsLoc = parser.getCurrentLocation();

  OpAsmParser::UnresolvedOperand basePtrInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfos;
  Type basePtrType;
  if (parser.parseOperand(basePtrInfo))
    return failure();
  if (parser.parseOperandList(indexInfos, OpAsmParser::Delimiter::Square))
    return failure();
  if (parser.parseColonType(basePtrType))
    return failure();
  if (parser.resolveOperand(basePtrInfo, basePtrType, result.operands))
    return failure();

  // Check that the provided indices list is not empty before parsing their
  // type list.
  if (indexInfos.empty()) {
    return mlir::emitError(result.location,
                           "'spirv.AccessChain' op expected at "
                           "least one index ");
  }

  SmallVector<Type, 4> indexTypes;
  if (parser.parseComma() || parser.parseTypeList(indexTypes))
    return failure();

  // Every index needs exactly one type.
  if (indexTypes.size() != indexInfos.size()) {
    return mlir::emitError(
        result.location, "'spirv.AccessChain' op indices types' count must be "
                         "equal to indices info count");
  }

  if (parser.resolveOperands(indexInfos, indexTypes, operandsLoc,
                             result.operands))
    return failure();

  // The first resolved operand is the base pointer; the rest are the indices.
  Type resultType = getElementPtrType(
      basePtrType, llvm::ArrayRef(result.operands).drop_front(),
      result.location);
  if (!resultType)
    return failure();

  result.addTypes(resultType);
  return success();
}

/// Prints ` base-ptr[indices] : base-ptr-type, index-types` for `op`.
template <typename Op>
static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
  printer << ' ' << op.getBasePtr() << '[' << indices << "] : ";
  printer << op.getBasePtr().getType() << ", " << indices.getTypes();
}

/// Prints the op through `printAccessChain` using its index operands.
void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
  ValueRange indexOperands = getIndices();
  printAccessChain(*this, indexOperands, printer);
}

/// Verifies that the result type of `accessChainOp` is the pointer type that
/// `getElementPtrType` derives from its base pointer type and `indices`.
/// Emits an error and returns failure on any mismatch.
template <typename Op>
static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
  Type expectedType = getElementPtrType(accessChainOp.getBasePtr().getType(),
                                        indices, accessChainOp.getLoc());
  // getElementPtrType has already emitted a diagnostic in this case.
  if (!expectedType)
    return failure();

  Type providedType = accessChainOp.getType();
  auto providedPtrType = llvm::dyn_cast<spirv::PointerType>(providedType);
  if (!providedPtrType) {
    // Report the actual result type; the failed dyn_cast result is null and
    // would not tell the user what was provided.
    return accessChainOp.emitOpError(
               "result type must be a pointer, but provided ")
           << providedType;
  }

  if (expectedType != providedPtrType)
    return accessChainOp.emitOpError("invalid result type: expected ")
           << expectedType << ", but provided " << providedPtrType;

  return success();
}

/// Verifies the op through `verifyAccessChain` using its index operands.
LogicalResult AccessChainOp::verify() {
  ValueRange indexOperands = getIndices();
  return verifyAccessChain(*this, indexOperands);
}

//===----------------------------------------------------------------------===//
// spirv.LoadOp
//===----------------------------------------------------------------------===//

/// Builds a LoadOp whose result type is the pointee type of `basePtr`, which
/// must be of spirv::PointerType.
void LoadOp::build(OpBuilder &builder, OperationState &state, Value basePtr,
                   MemoryAccessAttr memoryAccess, IntegerAttr alignment) {
  Type loadedType =
      llvm::cast<spirv::PointerType>(basePtr.getType()).getPointeeType();
  build(builder, state, loadedType, basePtr, memoryAccess, alignment);
}

/// Parses `storage-class ptr memory-access? attr-dict? : element-type`.
/// The pointer operand is resolved against a pointer to the element type in
/// the parsed storage class, and the element type becomes the result type.
ParseResult LoadOp::parse(OpAsmParser &parser, OperationState &result) {
  // Storage class specification followed by the pointer operand.
  spirv::StorageClass storageClass;
  OpAsmParser::UnresolvedOperand ptrInfo;
  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo))
    return failure();

  // Optional memory access list and attribute dictionary.
  if (parseMemoryAccessAttributes<LoadOp>(parser, result) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();

  Type elementType;
  if (parser.parseColon() || parser.parseType(elementType))
    return failure();

  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (parser.resolveOperand(ptrInfo, ptrType, result.operands))
    return failure();

  result.addTypes(elementType);
  return success();
}

/// Prints the op as `"storage-class" ptr memory-access? attr-dict? : type`.
void LoadOp::print(OpAsmPrinter &printer) {
  auto ptrType = llvm::cast<spirv::PointerType>(getPtr().getType());
  StringRef storageClassName =
      stringifyStorageClass(ptrType.getStorageClass());
  printer << " \"" << storageClassName << "\" " << getPtr();

  // Attributes printed by the helper are recorded so the dictionary below
  // does not repeat them.
  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(*this, printer, elidedAttrs);
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);

  printer << " : " << getType();
}

/// Verifies that the loaded value type matches the pointee type of the
/// pointer operand and that the memory access attributes are consistent.
LogicalResult LoadOp::verify() {
  // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
  // type with fixed size; i.e., it cannot be, nor include, any
  // OpTypeRuntimeArray types."
  LogicalResult typesMatch =
      verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue());
  if (failed(typesMatch))
    return failure();

  return verifyMemoryAccessAttribute(*this);
}

//===----------------------------------------------------------------------===//
// spirv.StoreOp
//===----------------------------------------------------------------------===//

/// Parses
///   `storage-class ptr, value memory-access? : element-type attr-dict?`.
/// The two operands are resolved as a pointer to the element type (in the
/// parsed storage class) and a value of the element type.
ParseResult StoreOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse the storage class specification
  spirv::StorageClass storageClass;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
  auto loc = parser.getCurrentLocation();
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser) ||
      parser.parseOperandList(operandInfo, 2) ||
      parseMemoryAccessAttributes<StoreOp>(parser, result) ||
      parser.parseColon() || parser.parseType(elementType)) {
    return failure();
  }

  // StoreOp::print emits an optional attribute dictionary after the type;
  // accept it here so that printed ops carrying extra attributes can be
  // parsed back.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  auto ptrType = spirv::PointerType::get(elementType, storageClass);
  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
                             result.operands)) {
    return failure();
  }
  return success();
}

/// Prints the op as
///   `"storage-class" ptr, value memory-access? : value-type attr-dict?`.
void StoreOp::print(OpAsmPrinter &printer) {
  auto ptrType = llvm::cast<spirv::PointerType>(getPtr().getType());
  StringRef storageClassName =
      stringifyStorageClass(ptrType.getStorageClass());
  printer << " \"" << storageClassName << "\" " << getPtr() << ", "
          << getValue();

  // Attributes printed by the helper are recorded so the dictionary below
  // does not repeat them.
  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(*this, printer, elidedAttrs);

  printer << " : " << getValue().getType();
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
}

/// Verifies that the stored value type matches the pointee type of the
/// pointer operand and that the memory access attributes are consistent.
LogicalResult StoreOp::verify() {
  // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
  // OpTypePointer whose Type operand is the same as the type of Object."
  LogicalResult typesMatch =
      verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue());
  if (failed(typesMatch))
    return failure();

  return verifyMemoryAccessAttribute(*this);
}

//===----------------------------------------------------------------------===//
// spirv.CopyMemory
//===----------------------------------------------------------------------===//

/// Prints the op as
///   `"target-sc" target, "source-sc" source memory-access?, memory-access?
///    : pointee-type attr-dict?`.
void CopyMemoryOp::print(OpAsmPrinter &printer) {
  printer << ' ';

  StringRef targetStorageClass = stringifyStorageClass(
      llvm::cast<spirv::PointerType>(getTarget().getType()).getStorageClass());
  printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";

  StringRef sourceStorageClass = stringifyStorageClass(
      llvm::cast<spirv::PointerType>(getSource().getType()).getStorageClass());
  printer << " \"" << sourceStorageClass << "\" " << getSource();

  // Attributes printed by the helpers are recorded so the dictionary below
  // does not repeat them.
  SmallVector<StringRef, 4> elidedAttrs;
  printMemoryAccessAttribute(*this, printer, elidedAttrs);
  printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
                                   getSourceMemoryAccess(),
                                   getSourceAlignment());

  Type pointeeType =
      llvm::cast<spirv::PointerType>(getTarget().getType()).getPointeeType();
  printer << " : " << pointeeType;

  // CopyMemoryOp::parse only accepts the optional attribute dictionary after
  // the type, so it must be printed last for the output to be parseable.
  printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
}

/// Parses
///   `target-sc target, source-sc source memory-access? (, memory-access?)?
///    : element-type attr-dict?`.
/// Both operands are resolved as pointers to the element type in their
/// respective storage classes.
ParseResult CopyMemoryOp::parse(OpAsmParser &parser, OperationState &result) {
  // Target storage class and pointer, followed by a comma.
  spirv::StorageClass targetStorageClass;
  OpAsmParser::UnresolvedOperand targetPtrInfo;
  if (parseEnumStrAttr(targetStorageClass, parser) ||
      parser.parseOperand(targetPtrInfo) || parser.parseComma())
    return failure();

  // Source storage class and pointer.
  spirv::StorageClass sourceStorageClass;
  OpAsmParser::UnresolvedOperand sourcePtrInfo;
  if (parseEnumStrAttr(sourceStorageClass, parser) ||
      parser.parseOperand(sourcePtrInfo))
    return failure();

  // First (optional) memory access list.
  if (parseMemoryAccessAttributes<CopyMemoryOp>(parser, result))
    return failure();

  // A comma introduces the second (source) memory access list.
  if (succeeded(parser.parseOptionalComma()) &&
      parseSourceMemoryAccessAttributes<CopyMemoryOp>(parser, result))
    return failure();

  Type elementType;
  if (parser.parseColon() || parser.parseType(elementType))
    return failure();

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
  if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands))
    return failure();

  auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
  if (parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands))
    return failure();

  return success();
}

/// Verifies that target and source point to the same type and that both the
/// target and the source memory access attributes are consistent.
LogicalResult CopyMemoryOp::verify() {
  auto targetPtrType = llvm::cast<spirv::PointerType>(getTarget().getType());
  auto sourcePtrType = llvm::cast<spirv::PointerType>(getSource().getType());

  if (targetPtrType.getPointeeType() != sourcePtrType.getPointeeType())
    return emitOpError("both operands must be pointers to the same type");

  LogicalResult targetAccessOk = verifyMemoryAccessAttribute(*this);
  if (failed(targetAccessOk))
    return failure();

  // TODO - According to the spec:
  //
  // If two masks are present, the first applies to Target and cannot include
  // MakePointerVisible, and the second applies to Source and cannot include
  // MakePointerAvailable.
  //
  // Add such verification here.

  return verifySourceMemoryAccessAttribute(*this);
}

/// Shared parser for the pointer access chain ops. Parses
///   `base-ptr [ element, indices ] : base-ptr-type , operand-types`
/// and adds the result type computed by `getElementPtrType` to `state`.
/// `opName` is only used as a prefix in diagnostics.
static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
                                             OpAsmParser &parser,
                                             OperationState &state) {
  auto operandsLoc = parser.getCurrentLocation();

  OpAsmParser::UnresolvedOperand basePtrInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> bracketInfos;
  Type basePtrType;
  if (parser.parseOperand(basePtrInfo))
    return failure();
  if (parser.parseOperandList(bracketInfos, OpAsmParser::Delimiter::Square))
    return failure();
  if (parser.parseColonType(basePtrType))
    return failure();
  if (parser.resolveOperand(basePtrInfo, basePtrType, state.operands))
    return failure();

  // Check that the provided indices list is not empty before parsing their
  // type list.
  if (bracketInfos.empty())
    return emitError(state.location) << opName << " expected element";

  SmallVector<Type, 4> bracketTypes;
  if (parser.parseComma() || parser.parseTypeList(bracketTypes))
    return failure();

  // Every bracketed operand needs exactly one type.
  if (bracketTypes.size() != bracketInfos.size())
    return emitError(state.location)
           << opName
           << " indices types' count must be equal to indices info count";

  if (parser.resolveOperands(bracketInfos, bracketTypes, operandsLoc,
                             state.operands))
    return failure();

  // Skip the first two resolved operands (base pointer and the first
  // bracketed operand) when deducing the result type.
  Type resultType = getElementPtrType(
      basePtrType, llvm::ArrayRef(state.operands).drop_front(2),
      state.location);
  if (!resultType)
    return failure();

  state.addTypes(resultType);
  return success();
}

/// Returns a vector holding `op.getElement()` followed by all values of
/// `op.getIndices()`.
template <typename Op>
static auto concatElemAndIndices(Op op) {
  auto indexOperands = op.getIndices();
  SmallVector<Value> ret;
  ret.reserve(indexOperands.size() + 1);
  ret.push_back(op.getElement());
  ret.append(indexOperands.begin(), indexOperands.end());
  return ret;
}

//===----------------------------------------------------------------------===//
// spirv.InBoundsPtrAccessChainOp
//===----------------------------------------------------------------------===//

/// Builds the op with a result type deduced from `basePtr`'s type and
/// `indices` via `getElementPtrType`.
void InBoundsPtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
                                     Value basePtr, Value element,
                                     ValueRange indices) {
  Type resultPtrType =
      getElementPtrType(basePtr.getType(), indices, state.location);
  assert(resultPtrType &&
         "Unable to deduce return type based on basePtr and indices");
  build(builder, state, resultPtrType, basePtr, element, indices);
}

/// Parses the op through the shared pointer access chain parser, passing
/// this op's name for diagnostics.
ParseResult InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
                                            OperationState &result) {
  StringRef opName = spirv::InBoundsPtrAccessChainOp::getOperationName();
  return parsePtrAccessChainOpImpl(opName, parser, result);
}

/// Prints the op through `printAccessChain`, with the element operand
/// followed by the indices inside the brackets.
void InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
  auto elementAndIndices = concatElemAndIndices(*this);
  printAccessChain(*this, elementAndIndices, printer);
}

/// Verifies the op through `verifyAccessChain` using its index operands.
LogicalResult InBoundsPtrAccessChainOp::verify() {
  ValueRange indexOperands = getIndices();
  return verifyAccessChain(*this, indexOperands);
}

//===----------------------------------------------------------------------===//
// spirv.PtrAccessChainOp
//===----------------------------------------------------------------------===//

/// Builds the op with a result type deduced from `basePtr`'s type and
/// `indices` via `getElementPtrType`.
void PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
                             Value basePtr, Value element, ValueRange indices) {
  Type resultPtrType =
      getElementPtrType(basePtr.getType(), indices, state.location);
  assert(resultPtrType &&
         "Unable to deduce return type based on basePtr and indices");
  build(builder, state, resultPtrType, basePtr, element, indices);
}

/// Parses the op through the shared pointer access chain parser, passing
/// this op's name for diagnostics.
ParseResult PtrAccessChainOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  StringRef opName = spirv::PtrAccessChainOp::getOperationName();
  return parsePtrAccessChainOpImpl(opName, parser, result);
}

/// Prints the op through `printAccessChain`, with the element operand
/// followed by the indices inside the brackets.
void PtrAccessChainOp::print(OpAsmPrinter &printer) {
  auto elementAndIndices = concatElemAndIndices(*this);
  printAccessChain(*this, elementAndIndices, printer);
}

/// Verifies the op through `verifyAccessChain` using its index operands.
LogicalResult PtrAccessChainOp::verify() {
  ValueRange indexOperands = getIndices();
  return verifyAccessChain(*this, indexOperands);
}

//===----------------------------------------------------------------------===//
// spirv.Variable
//===----------------------------------------------------------------------===//

/// Parses `(init(initializer))? decorations : ptr-type`.
/// The parsed type must be a spirv.ptr type; it becomes the result type, the
/// initializer (if any) is resolved against its pointee type, and its storage
/// class is recorded as the storage class attribute.
ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
  // Optional `init(<operand>)` clause.
  std::optional<OpAsmParser::UnresolvedOperand> initInfo;
  if (succeeded(parser.parseOptionalKeyword("init"))) {
    initInfo = OpAsmParser::UnresolvedOperand();
    if (parser.parseLParen())
      return failure();
    if (parser.parseOperand(*initInfo))
      return failure();
    if (parser.parseRParen())
      return failure();
  }

  if (parseVariableDecorations(parser, result))
    return failure();

  // Result pointer type; remember where it starts for the diagnostic below.
  if (parser.parseColon())
    return failure();
  auto typeLoc = parser.getCurrentLocation();
  Type parsedType;
  if (parser.parseType(parsedType))
    return failure();

  auto ptrType = llvm::dyn_cast<spirv::PointerType>(parsedType);
  if (!ptrType)
    return parser.emitError(typeLoc, "expected spirv.ptr type");
  result.addTypes(ptrType);

  // The initializer must have the pointee type of the result pointer.
  if (initInfo && parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
                                        result.operands))
    return failure();

  auto storageClassAttr =
      parser.getBuilder().getAttr<spirv::StorageClassAttr>(
          ptrType.getStorageClass());
  result.addAttribute(spirv::attributeName<spirv::StorageClass>(),
                      storageClassAttr);

  return success();
}

/// Prints the op as `(init(initializer))? decorations : type`. The storage
/// class attribute is passed to the decoration printer as already elided.
void VariableOp::print(OpAsmPrinter &printer) {
  // Optional initializer.
  if (getNumOperands() != 0)
    printer << " init(" << getInitializer() << ")";

  SmallVector<StringRef, 4> elidedAttrs;
  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
  printVariableDecorations(*this, printer, elidedAttrs);

  printer << " : " << getType();
}

/// Verifies the op:
///  * the storage class is `Function` and matches the result pointer's;
///  * an initializer, if present, is defined by a spirv.Constant,
///    spirv.ReferenceOf or spirv.AddressOf op;
///  * no descriptor set, binding or built-in decoration attribute is present;
///  * a pointee that is (an array of) PhysicalStorageBuffer pointers carries
///    exactly one of the AliasedPointer / RestrictPointer decorations.
LogicalResult VariableOp::verify() {
  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
  // object. It cannot be Generic. It must be the same as the Storage Class
  // operand of the Result Type."
  const spirv::StorageClass storageClass = getStorageClass();
  if (storageClass != spirv::StorageClass::Function) {
    return emitOpError(
        "can only be used to model function-level variables. Use "
        "spirv.GlobalVariable for module-level variables.");
  }

  auto resultPtrType = llvm::cast<spirv::PointerType>(getPointer().getType());
  if (storageClass != resultPtrType.getStorageClass())
    return emitOpError(
        "storage class must match result pointer's storage class");

  if (getNumOperands() != 0) {
    // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
    // a global (module scope) OpVariable instruction".
    Operation *initOp = getOperand(0).getDefiningOp();
    const bool isAllowedInit =
        initOp && isa<spirv::ConstantOp,    // for normal constant
                      spirv::ReferenceOfOp, // for spec constant
                      spirv::AddressOfOp>(initOp);
    if (!isAllowedInit)
      return emitOpError("initializer must be the result of a "
                         "constant or spirv.GlobalVariable op");
  }

  // Looks up the attribute named after the snake_case form of a decoration.
  Operation *thisOp = getOperation();
  auto lookupDecoration = [thisOp](spirv::Decoration decoration) {
    return thisOp->getAttr(
        llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)));
  };

  // TODO: generate these strings using ODS.
  for (spirv::Decoration decoration :
       {spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
        spirv::Decoration::BuiltIn}) {
    if (!lookupDecoration(decoration))
      continue;
    return emitOpError("cannot have '")
           << llvm::convertToSnakeFromCamelCase(
                  stringifyDecoration(decoration))
           << "' attribute (only allowed in spirv.GlobalVariable)";
  }

  // From SPV_KHR_physical_storage_buffer:
  // > If an OpVariable's pointee type is a pointer (or array of pointers) in
  // > PhysicalStorageBuffer storage class, then the variable must be decorated
  // > with exactly one of AliasedPointer or RestrictPointer.
  auto pointeePtrType = dyn_cast<spirv::PointerType>(getPointeeType());
  if (!pointeePtrType)
    if (auto pointeeArrayType = dyn_cast<spirv::ArrayType>(getPointeeType()))
      pointeePtrType =
          dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());

  // Nothing more to check unless the pointee is (an array of) physical
  // storage buffer pointers.
  if (!pointeePtrType || pointeePtrType.getStorageClass() !=
                             spirv::StorageClass::PhysicalStorageBuffer)
    return success();

  const bool hasAliasedPtr =
      lookupDecoration(spirv::Decoration::AliasedPointer) != nullptr;
  const bool hasRestrictPtr =
      lookupDecoration(spirv::Decoration::RestrictPointer) != nullptr;

  if (!hasAliasedPtr && !hasRestrictPtr)
    return emitOpError() << " with physical buffer pointer must be decorated "
                            "either 'AliasedPointer' or 'RestrictPointer'";

  if (hasAliasedPtr && hasRestrictPtr)
    return emitOpError()
           << " with physical buffer pointer must have exactly one "
              "aliasing decoration";

  return success();
}

} // namespace mlir::spirv