//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"

#include "SPIRVOpUtils.h"
#include "SPIRVParsingUtils.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include <cassert>
#include <numeric>
#include <optional>
#include <type_traits>

using namespace mlir;
using namespace mlir::spirv::AttrNames;

//===----------------------------------------------------------------------===//
// Common utility functions
//===----------------------------------------------------------------------===//

/// Reads the integer payload of the `spirv.Constant` defined by `op` into
/// `value`.
///
/// Fails, leaving `value` untouched, when `op` is null, is not a
/// `spirv.Constant`, or the constant's value is not an IntegerAttr. The 64-bit
/// attribute payload is narrowed to `int32_t` by plain conversion.
LogicalResult spirv::extractValueFromConstOp(Operation *op, int32_t &value) {
  auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
  auto intAttr =
      constOp ? llvm::dyn_cast<IntegerAttr>(constOp.getValue()) : IntegerAttr();
  if (!intAttr)
    return failure();

  // Signless integers are read with getInt(); any other signedness goes
  // through getSInt().
  value = intAttr.getType().isSignlessInteger() ? intAttr.getInt()
                                                : intAttr.getSInt();
  return success();
}

/// Checks that `memorySemantics` sets at most one of the four ordering bits
/// (Acquire, Release, AcquireRelease, SequentiallyConsistent), emitting an
/// error on `op` otherwise. Other bits in the mask are not inspected.
LogicalResult
spirv::verifyMemorySemantics(Operation *op,
                             spirv::MemorySemantics memorySemantics) {
  // According to the SPIR-V specification:
  // "Despite being a mask and allowing multiple bits to be combined, it is
  // invalid for more than one of these four bits to be set: Acquire, Release,
  // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
  // Release semantics is done by setting the AcquireRelease bit, not by setting
  // two bits."
  const auto orderingBits = spirv::MemorySemantics::Acquire |
                            spirv::MemorySemantics::Release |
                            spirv::MemorySemantics::AcquireRelease |
                            spirv::MemorySemantics::SequentiallyConsistent;
  const auto presentOrderingBits =
      static_cast<uint32_t>(memorySemantics & orderingBits);

  if (llvm::popcount(presentOrderingBits) <= 1)
    return success();

  return op->emitError("expected at most one of these four memory constraints "
                       "to be set: `Acquire`, `Release`,"
                       "`AcquireRelease` or `SequentiallyConsistent`");
}

/// Prints the pretty-form decorations of a SPIR-V variable-like op, then the
/// remaining attribute dictionary:
///   - ` bind(<set>, <binding>)` when both the DescriptorSet and Binding
///     decoration attributes (IntegerAttrs, snake_case names) are present;
///   - ` <builtin-name>("<value>")` when the BuiltIn decoration attribute
///     (a StringAttr) is present.
/// Attributes printed in these forms are appended to `elidedAttrs` so the
/// dictionary does not repeat them.
///
/// The snake_case names are produced as temporary std::strings. The StringRefs
/// appended to the caller-owned `elidedAttrs` are therefore taken from
/// context-uniqued StringAttrs, so they stay valid after this function
/// returns. Previously they pointed into locals that were destroyed on return.
void spirv::printVariableDecorations(Operation *op, OpAsmPrinter &printer,
                                     SmallVectorImpl<StringRef> &elidedAttrs) {
  auto *ctx = op->getContext();
  // Returns a view of `name` whose storage is owned by the MLIRContext.
  auto internName = [ctx](StringRef name) -> StringRef {
    return StringAttr::get(ctx, name).getValue();
  };

  // Print optional descriptor binding (only when both halves are present).
  StringRef descriptorSetName = internName(llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::DescriptorSet)));
  StringRef bindingName = internName(llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::Binding)));
  auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
  auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
  if (descriptorSet && binding) {
    elidedAttrs.push_back(descriptorSetName);
    elidedAttrs.push_back(bindingName);
    printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
            << ")";
  }

  // Print BuiltIn attribute if present.
  StringRef builtInName = internName(llvm::convertToSnakeFromCamelCase(
      stringifyDecoration(spirv::Decoration::BuiltIn)));
  if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
    printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
    elidedAttrs.push_back(builtInName);
  }

  printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
}

/// Parses a single-result op in one of two forms:
///   - compact:  `%a, %b {attrs} : type`, where one type covers every operand
///     and the result;
///   - generic:  `(%a, %b) {attrs} : (in-types) -> (out-types)`, the fallback
///     emitted by `printOneResultOp` when the types differ.
static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
                                                   OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandList;
  Type parsedType;
  SMLoc startLoc = parser.getCurrentLocation();

  // No leading '(' means the compact form with a single shared type.
  if (parser.parseOptionalLParen())
    return failure(parser.parseOperandList(operandList) ||
                   parser.parseOptionalAttrDict(result.attributes) ||
                   parser.parseColonType(parsedType) ||
                   parser.resolveOperands(operandList, parsedType,
                                          result.operands) ||
                   parser.addTypeToList(parsedType, result.types));

  // Generic form: operands in parentheses followed by a function type.
  if (parser.parseOperandList(operandList) || parser.parseRParen() ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColon() || parser.parseType(parsedType))
    return failure();

  auto signature = llvm::dyn_cast<FunctionType>(parsedType);
  if (!signature) {
    parser.emitError(startLoc, "expected function type");
    return failure();
  }
  if (parser.resolveOperands(operandList, signature.getInputs(), startLoc,
                             result.operands))
    return failure();
  result.addTypes(signature.getResults());
  return success();
}

/// Prints a single-result op as `%ops {attrs} : type` when every operand has
/// the result type. Otherwise it falls back to the generic form (without the
/// op name) so no type information is dropped.
static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
  assert(op->getNumResults() == 1 && "op should have one result");

  Type resultType = op->getResult(0).getType();
  const bool hasMismatchedOperand =
      llvm::any_of(op->getOperandTypes(), [&](Type operandType) {
        return operandType != resultType;
      });

  if (!hasMismatchedOperand) {
    p << ' ';
    p.printOperands(op->getOperands());
    p.printOptionalAttrDict(op->getAttrs());
    // One type stands for all operands and the result.
    p << " : " << resultType;
    return;
  }

  p.printGenericOp(op, /*printOpName=*/false);
}

/// Verifies the optional Image Operands mask `attr` of `imageOp` against the
/// trailing `operands` it describes.
///
/// A missing mask is valid only when there are no trailing operands. Masks
/// that hit the not-yet-implemented operand kinds are rejected with a
/// diagnostic rather than aborting.
template <typename Op>
static LogicalResult verifyImageOperands(Op imageOp,
                                         spirv::ImageOperandsAttr attr,
                                         Operation::operand_range operands) {
  if (!attr) {
    if (operands.empty())
      return success();

    return imageOp.emitError("the Image Operands should encode what operands "
                             "follow, as per Image Operands");
  }

  // TODO: Add the validation rules for the following Image Operands.
  spirv::ImageOperands noSupportOperands =
      spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
      spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
      spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
      spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
      spirv::ImageOperands::MakeTexelAvailable |
      spirv::ImageOperands::MakeTexelVisible |
      spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;

  // NOTE(review): `bitEnumContainsAll` only fires when *every* listed bit is
  // set. If the intent is to reject any unsupported operand, this should be
  // `bitEnumContainsAny`. Confirm before tightening, since that would reject
  // IR that verifies today.
  //
  // Verifiers run on user-provided IR, so report an error instead of calling
  // llvm_unreachable (undefined behavior in release builds).
  if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
    return imageOp.emitError("unimplemented operands of Image Operands");

  return success();
}

/// Checks that the scalar type of `val` equals the pointee type of `ptr`. The
/// scalar type is `val`'s type, or its element type when `val` is a vector.
/// `ptr` is unconditionally cast to spirv::PointerType.
template <typename BlockReadWriteOpTy>
static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
                                                        Value ptr, Value val) {
  Type scalarType = val.getType();
  if (auto vecType = llvm::dyn_cast<VectorType>(scalarType))
    scalarType = vecType.getElementType();

  Type pointeeType =
      llvm::cast<spirv::PointerType>(ptr.getType()).getPointeeType();
  if (scalarType == pointeeType)
    return success();

  return op.emitOpError("mismatch in result type and pointer type");
}

/// Descends from `type` through nested composite types following `indices`,
/// potentially down to component granularity, and returns the type reached.
///
/// Returns a null type, after reporting through `emitErrorFn`, when:
///   - `indices` is empty;
///   - an index is negative or past the end of a composite whose element
///     count is known at compile time;
///   - a non-composite type is indexed into.
static Type
getElementType(Type type, ArrayRef<int32_t> indices,
               function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
  if (indices.empty()) {
    emitErrorFn("expected at least one index for spirv.CompositeExtract");
    return nullptr;
  }

  Type current = type;
  for (int32_t index : indices) {
    auto compositeType = llvm::dyn_cast<spirv::CompositeType>(current);
    if (!compositeType) {
      emitErrorFn("cannot extract from non-composite type ")
          << current << " with index " << index;
      return nullptr;
    }

    // Bounds can only be checked when the element count is statically known.
    const bool outOfBounds =
        compositeType.hasCompileTimeKnownNumElements() &&
        (index < 0 ||
         static_cast<uint64_t>(index) >= compositeType.getNumElements());
    if (outOfBounds) {
      emitErrorFn("index ") << index << " out of bounds for " << current;
      return nullptr;
    }

    current = compositeType.getElementType(index);
  }
  return current;
}

/// Attribute-based overload. It requires `indices` to be a non-empty ArrayAttr
/// of IntegerAttrs, converts the elements to int32_t, and delegates to the
/// ArrayRef<int32_t> overload. Returns a null type, after reporting via
/// `emitErrorFn`, on malformed input.
static Type
getElementType(Type type, Attribute indices,
               function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
  auto arrayAttr = llvm::dyn_cast<ArrayAttr>(indices);
  if (!arrayAttr) {
    emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
    return nullptr;
  }
  if (arrayAttr.empty()) {
    emitErrorFn("expected at least one index for spirv.CompositeExtract");
    return nullptr;
  }

  SmallVector<int32_t, 2> path;
  for (Attribute element : arrayAttr) {
    auto intAttr = llvm::dyn_cast<IntegerAttr>(element);
    if (!intAttr) {
      emitErrorFn("expected an 32-bit integer for index, but found '")
          << element << "'";
      return nullptr;
    }
    // getInt() yields an int64_t, which is narrowed to int32_t here.
    path.push_back(intAttr.getInt());
  }
  return getElementType(type, path, emitErrorFn);
}

/// Convenience overload that reports errors with mlir::emitError at `loc`.
static Type getElementType(Type type, Attribute indices, Location loc) {
  return getElementType(type, indices,
                        [&](StringRef message) -> InFlightDiagnostic {
                          return ::mlir::emitError(loc, message);
                        });
}

/// Convenience overload for parsers that reports errors through `parser` at
/// source location `loc`.
static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
                           SMLoc loc) {
  return getElementType(type, indices,
                        [&](StringRef message) -> InFlightDiagnostic {
                          return parser.emitError(loc, message);
                        });
}

/// Verifies an extended binary arithmetic op. Its result must be a struct with
/// exactly two members, and both operands and both members must share a
/// single type. The result type is unconditionally cast to spirv::StructType.
template <typename ExtendedBinaryOp>
static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
  auto structType = llvm::cast<spirv::StructType>(op.getType());
  if (structType.getNumElements() != 2)
    return op.emitOpError("expected result struct type containing two members");

  const bool allSame = llvm::all_equal(
      {op.getOperand1().getType(), op.getOperand2().getType(),
       structType.getElementType(0), structType.getElementType(1)});
  if (allSame)
    return success();

  return op.emitOpError(
      "expected all operand types and struct member types are the same");
}

/// Parses `{attrs} %lhs, %rhs : !spirv.struct<(T, T)>`. Both operands are
/// resolved with the type of the struct's first member.
static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
                                                   OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandList;
  Type structTy;
  SMLoc typeLoc;
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseOperandList(operandList) || parser.parseColon() ||
      parser.getCurrentLocation(&typeLoc) || parser.parseType(structTy))
    return failure();

  auto resultStruct = llvm::dyn_cast<spirv::StructType>(structTy);
  if (!resultStruct || resultStruct.getNumElements() != 2)
    return parser.emitError(typeLoc,
                            "expected spirv.struct type with two members");

  // Both operands take the type of the first struct member.
  SmallVector<Type, 2> operandTypes(2, resultStruct.getElementType(0));
  if (parser.resolveOperands(operandList, operandTypes, typeLoc,
                             result.operands))
    return failure();

  result.addTypes(structTy);
  return success();
}

/// Prints the form read by parseArithmeticExtendedBinaryOp: the optional
/// attribute dictionary first, then the operands, then the struct result type.
static void printArithmeticExtendedBinaryOp(Operation *op,
                                            OpAsmPrinter &printer) {
  Type structType = op->getResultTypes().front();
  printer << ' ';
  printer.printOptionalAttrDict(op->getAttrs());
  printer.printOperands(op->getOperands());
  printer << " : " << structType;
}

/// Verifies that the op's first operand and its first result have the same
/// type.
static LogicalResult verifyShiftOp(Operation *op) {
  Type baseType = op->getOperand(0).getType();
  Type resultType = op->getResult(0).getType();
  if (baseType == resultType)
    return success();

  return op->emitError("expected the same type for the first operand and "
                       "result, but provided ")
         << baseType << " and " << resultType;
}

//===----------------------------------------------------------------------===//
// spirv.mlir.addressof
//===----------------------------------------------------------------------===//

/// Builds an address-of op referring to `var`. The result has the global
/// variable's own type.
void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
                               spirv::GlobalVariableOp var) {
  auto varSymbol = SymbolRefAttr::get(var);
  build(builder, state, var.getType(), varSymbol);
}

/// Checks that the referenced symbol, looked up from the parent op, is a
/// spirv.GlobalVariable whose type equals this op's result type.
LogicalResult spirv::AddressOfOp::verify() {
  Operation *symbolOp = SymbolTable::lookupNearestSymbolFrom(
      (*this)->getParentOp(), getVariableAttr());
  auto globalVar = dyn_cast_or_null<spirv::GlobalVariableOp>(symbolOp);
  if (!globalVar)
    return emitOpError("expected spirv.GlobalVariable symbol");

  if (getPointer().getType() == globalVar.getType())
    return success();

  return emitOpError(
      "result type mismatch with the referenced global variable's type");
}

//===----------------------------------------------------------------------===//
// spirv.CompositeConstruct
//===----------------------------------------------------------------------===//

/// Verifies that the constituents of a spirv.CompositeConstruct can form its
/// result type. The accepted layouts are:
///   1. Cooperative / joint matrix results: exactly one constituent whose type
///      is the matrix element type.
///   2./3./4. Struct, array, or vector results with one constituent per
///      element, each having exactly that element's type.
///   4. Otherwise, vector results only: scalar and/or vector constituents
///      whose element types match the result element type and whose element
///      counts sum to the result's element count.
LogicalResult spirv::CompositeConstructOp::verify() {
  operand_range constituents = this->getConstituents();

  auto coopElementType =
      llvm::TypeSwitch<Type, Type>(getType())
          .Case<spirv::CooperativeMatrixType, spirv::JointMatrixINTELType>(
              [](auto coopType) { return coopType.getElementType(); })
          .Default([](Type) { return nullptr; });

  // Case 1. -- matrices.
  if (coopElementType) {
    if (constituents.size() != 1)
      return emitOpError("has incorrect number of operands: expected ")
             << "1, but provided " << constituents.size();
    if (coopElementType != constituents.front().getType())
      return emitOpError("operand type mismatch: expected operand type ")
             << coopElementType << ", but provided "
             << constituents.front().getType();
    return success();
  }

  // Case 2./3./4. -- number of constituents matches the number of elements.
  auto cType = llvm::cast<spirv::CompositeType>(getType());
  if (constituents.size() == cType.getNumElements()) {
    for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
      if (constituents[index].getType() != cType.getElementType(index)) {
        return emitOpError("operand type mismatch: expected operand type ")
               << cType.getElementType(index) << ", but provided "
               << constituents[index].getType();
      }
    }
    return success();
  }

  // Case 4. -- check that all constituents add up to the expected vector type.
  auto resultType = llvm::dyn_cast<VectorType>(cType);
  if (!resultType)
    return emitOpError(
        "expected to return a vector or cooperative matrix when the number of "
        "constituents is less than what the result needs");

  // Count scalar elements directly in an unsigned accumulator: each scalar
  // contributes 1, and each vector contributes its length.
  unsigned totalCount = 0;
  for (Value component : constituents) {
    Type componentType = component.getType();
    auto vectorType = llvm::dyn_cast<VectorType>(componentType);
    if (!vectorType && !componentType.isIntOrFloat())
      return emitOpError("operand type mismatch: expected operand to have "
                         "a scalar or vector type, but provided ")
             << componentType;

    Type elementType = componentType;
    if (vectorType) {
      totalCount += vectorType.getNumElements();
      elementType = vectorType.getElementType();
    } else {
      totalCount += 1;
    }

    if (elementType != resultType.getElementType())
      return emitOpError("operand element type mismatch: expected to be ")
             << resultType.getElementType() << ", but provided " << elementType;
  }

  if (totalCount != cType.getNumElements())
    return emitOpError("has incorrect number of operands: expected ")
           << cType.getNumElements() << ", but provided " << totalCount;
  return success();
}

//===----------------------------------------------------------------------===//
// spirv.CompositeExtractOp
//===----------------------------------------------------------------------===//

/// Builds a spirv.CompositeExtract of `composite` at `indices`, deriving the
/// result type by walking the composite type.
///
/// If the indices are invalid, an error is emitted at `state.location` and the
/// function returns without populating `state`.
void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
                                      Value composite,
                                      ArrayRef<int32_t> indices) {
  ArrayAttr indexAttr = builder.getI32ArrayAttr(indices);
  if (Type elementType =
          getElementType(composite.getType(), indexAttr, state.location))
    build(builder, state, elementType, composite, indexAttr);
}

/// Parses `%composite[indices] : composite-type`. The result type is not
/// written in the assembly; it is derived from the composite type and the
/// index path.
ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
                                             OperationState &result) {
  OpAsmParser::UnresolvedOperand compositeOperand;
  Attribute indices;
  Type compositeType;
  SMLoc indicesLoc;
  StringRef indicesName =
      spirv::CompositeExtractOp::getIndicesAttrName(result.name);

  if (parser.parseOperand(compositeOperand) ||
      parser.getCurrentLocation(&indicesLoc) ||
      parser.parseAttribute(indices, indicesName, result.attributes) ||
      parser.parseColonType(compositeType) ||
      parser.resolveOperand(compositeOperand, compositeType, result.operands))
    return failure();

  Type extractedType =
      getElementType(compositeType, indices, parser, indicesLoc);
  if (!extractedType)
    return failure();

  result.addTypes(extractedType);
  return success();
}

/// Prints `%composite[indices] : composite-type`. The result type is implied.
void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
  Value composite = getComposite();
  printer << ' ' << composite << getIndices() << " : " << composite.getType();
}

/// Re-derives the extracted type from the composite type and indices, and
/// checks that it matches the declared result type.
LogicalResult spirv::CompositeExtractOp::verify() {
  auto indices = llvm::dyn_cast<ArrayAttr>(getIndices());
  Type expectedType =
      getElementType(getComposite().getType(), indices, getLoc());
  if (!expectedType)
    return failure();

  if (expectedType == getType())
    return success();

  return emitOpError("invalid result type: expected ")
         << expectedType << " but provided " << getType();
}

//===----------------------------------------------------------------------===//
// spirv.CompositeInsert
//===----------------------------------------------------------------------===//

/// Builds a spirv.CompositeInsert of `object` into `composite` at `indices`.
/// The result has the composite's type.
void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
                                     Value object, Value composite,
                                     ArrayRef<int32_t> indices) {
  Type resultType = composite.getType();
  build(builder, state, resultType, object, composite,
        builder.getI32ArrayAttr(indices));
}

/// Parses `%object, %composite[indices] : object-type into composite-type`.
/// The result takes the composite type.
ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
                                            OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandList;
  Type objectTy;
  Type compositeTy;
  Attribute indices;
  StringRef indicesName =
      spirv::CompositeInsertOp::getIndicesAttrName(result.name);
  SMLoc operandsLoc = parser.getCurrentLocation();

  if (parser.parseOperandList(operandList, 2) ||
      parser.parseAttribute(indices, indicesName, result.attributes) ||
      parser.parseColonType(objectTy) ||
      parser.parseKeywordType("into", compositeTy))
    return failure();

  // Operand order is `object, composite`.
  if (parser.resolveOperands(operandList, {objectTy, compositeTy}, operandsLoc,
                             result.operands))
    return failure();

  return parser.addTypesToList(compositeTy, result.types);
}

/// Checks three things:
///   - the object's type is the element type addressed by the indices;
///   - the result type equals the composite operand's type;
///   - the indices themselves are valid (checked via getElementType).
LogicalResult spirv::CompositeInsertOp::verify() {
  auto indices = llvm::dyn_cast<ArrayAttr>(getIndices());
  Type compositeType = getComposite().getType();
  Type expectedObjectType = getElementType(compositeType, indices, getLoc());
  if (!expectedObjectType)
    return failure();

  Type actualObjectType = getObject().getType();
  if (actualObjectType != expectedObjectType)
    return emitOpError("object operand type should be ")
           << expectedObjectType << ", but found " << actualObjectType;

  if (compositeType != getType())
    return emitOpError("result type should be the same as "
                       "the composite type, but found ")
           << compositeType << " vs " << getType();

  return success();
}

/// Prints `%object, %composite[indices] : object-type into composite-type`.
void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
  Value object = getObject();
  Value composite = getComposite();
  printer << " " << object << ", " << composite << getIndices() << " : "
          << object.getType() << " into " << composite.getType();
}

//===----------------------------------------------------------------------===//
// spirv.Constant
//===----------------------------------------------------------------------===//

/// Parses the value attribute, optionally followed by `: type`. The explicit
/// type is read only when the attribute's own type is NoneType or a tensor
/// type. Untyped attributes are treated as NoneType.
ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  Attribute valueAttr;
  StringRef attrName = spirv::ConstantOp::getValueAttrName(result.name);
  if (parser.parseAttribute(valueAttr, attrName, result.attributes))
    return failure();

  Type resultType;
  if (auto typedAttr = llvm::dyn_cast<TypedAttr>(valueAttr))
    resultType = typedAttr.getType();
  else
    resultType = NoneType::get(parser.getContext());

  const bool needsExplicitType = llvm::isa<NoneType, TensorType>(resultType);
  if (needsExplicitType && parser.parseColonType(resultType))
    return failure();

  return parser.addTypeToList(resultType, result.types);
}

/// Prints the value attribute. For spirv.array result types it also prints
/// ` : type`. Presumably this is because the attribute alone does not encode
/// the array type; confirm against ConstantOp::parse.
void spirv::ConstantOp::print(OpAsmPrinter &printer) {
  printer << ' ' << getValue();
  if (!llvm::isa<spirv::ArrayType>(getType()))
    return;
  printer << " : " << getType();
}

/// Checks that the constant attribute `value` is consistent with the result
/// type `opType`, emitting diagnostics on `op` on mismatch:
///   - IntegerAttr / FloatAttr: the attribute type must equal `opType`.
///   - Dense int/FP or sparse elements: the types match exactly, or `opType`
///     is a (possibly nested) spirv.array of int/float scalars. In that case
///     the innermost element type and the total element count must match the
///     attribute's shaped type.
///   - ArrayAttr: `opType` must be a spirv.array, and every element is verified
///     recursively against its element type.
/// Any other attribute kind is rejected.
static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
                                        Type opType) {
  if (llvm::isa<IntegerAttr, FloatAttr>(value)) {
    auto valueType = llvm::cast<TypedAttr>(value).getType();
    if (valueType != opType)
      return op.emitOpError("result type (")
             << opType << ") does not match value type (" << valueType << ")";
    return success();
  }
  if (llvm::isa<DenseIntOrFPElementsAttr, SparseElementsAttr>(value)) {
    auto valueType = llvm::cast<TypedAttr>(value).getType();
    if (valueType == opType)
      return success();
    auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
    auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
    if (!arrayType)
      return op.emitOpError("result or element type (")
             << opType << ") does not match value type (" << valueType
             << "), must be the same or spirv.array";

    // Flatten nested spirv.array types into a total element count. The count
    // is accumulated in 64 bits so that large nested shapes cannot overflow a
    // signed `int` (undefined behavior), and so it compares like-for-like with
    // ShapedType::getNumElements().
    int64_t numElements = arrayType.getNumElements();
    auto opElemType = arrayType.getElementType();
    while (auto t = llvm::dyn_cast<spirv::ArrayType>(opElemType)) {
      numElements *= static_cast<int64_t>(t.getNumElements());
      opElemType = t.getElementType();
    }
    if (!opElemType.isIntOrFloat())
      return op.emitOpError("only support nested array result type");

    auto valueElemType = shapedType.getElementType();
    if (valueElemType != opElemType) {
      return op.emitOpError("result element type (")
             << opElemType << ") does not match value element type ("
             << valueElemType << ")";
    }

    if (numElements != shapedType.getNumElements()) {
      return op.emitOpError("result number of elements (")
             << numElements << ") does not match value number of elements ("
             << shapedType.getNumElements() << ")";
    }
    return success();
  }
  if (auto arrayAttr = llvm::dyn_cast<ArrayAttr>(value)) {
    auto arrayType = llvm::dyn_cast<spirv::ArrayType>(opType);
    if (!arrayType)
      return op.emitOpError(
          "must have spirv.array result type for array value");
    Type elemType = arrayType.getElementType();
    for (Attribute element : arrayAttr.getValue()) {
      // Verify array elements recursively.
      if (failed(verifyConstantType(op, element, elemType)))
        return failure();
    }
    return success();
  }
  return op.emitOpError("cannot have attribute: ") << value;
}

LogicalResult spirv::ConstantOp::verify() {
  // The result type's own validity is checked by ODS-generated code. Here we
  // only check that the value attribute agrees with that result type.
  Attribute valueAttr = getValueAttr();
  Type resultType = getType();
  return verifyConstantType(*this, valueAttr, resultType);
}

/// Returns whether a spirv.Constant can be built with result `type`. `type`
/// must be a SPIR-V type. Among types owned by the SPIR-V dialect itself, only
/// spirv.array is currently accepted.
bool spirv::ConstantOp::isBuildableWith(Type type) {
  if (!llvm::isa<spirv::SPIRVType>(type))
    return false;

  // TODO: support constant struct
  const bool isDialectType = isa<SPIRVDialect>(type.getDialect());
  return !isDialectType || llvm::isa<spirv::ArrayType>(type);
}

/// Creates a spirv.Constant holding zero of `type` at `loc`:
///   - `false` for i1 and 0 for other integer widths;
///   - 0.0 for float types;
///   - a splat of zero for vectors with integer or float elements.
/// Any other type reaches llvm_unreachable.
spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
                                             OpBuilder &builder) {
  // Every branch below materializes a spirv.Constant of `type` at `loc`.
  auto makeConstant = [&](auto valueAttr) {
    return builder.create<spirv::ConstantOp>(loc, type, valueAttr);
  };

  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
    unsigned bitWidth = intType.getWidth();
    // 1-bit integers are emitted as a boolean attribute.
    if (bitWidth != 1)
      return makeConstant(builder.getIntegerAttr(type, APInt(bitWidth, 0)));
    return makeConstant(builder.getBoolAttr(false));
  }

  if (auto floatType = llvm::dyn_cast<FloatType>(type))
    return makeConstant(builder.getFloatAttr(floatType, 0.0));

  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
    Type elemType = vectorType.getElementType();
    if (llvm::isa<FloatType>(elemType))
      return makeConstant(DenseFPElementsAttr::get(
          vectorType, FloatAttr::get(elemType, 0.0).getValue()));
    if (llvm::isa<IntegerType>(elemType))
      return makeConstant(DenseElementsAttr::get(
          vectorType, IntegerAttr::get(elemType, 0).getValue()));
  }

  llvm_unreachable("unimplemented types for ConstantOp::getZero()");
}

/// Creates a spirv.Constant holding one of `type` at `loc`:
///   - `true` for i1 and 1 for other integer widths;
///   - 1.0 for float types;
///   - a splat of one for vectors with integer or float elements.
/// Any other type reaches llvm_unreachable.
spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
                                            OpBuilder &builder) {
  // Every branch below materializes a spirv.Constant of `type` at `loc`.
  auto makeConstant = [&](auto valueAttr) {
    return builder.create<spirv::ConstantOp>(loc, type, valueAttr);
  };

  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
    unsigned bitWidth = intType.getWidth();
    // 1-bit integers are emitted as a boolean attribute.
    if (bitWidth != 1)
      return makeConstant(builder.getIntegerAttr(type, APInt(bitWidth, 1)));
    return makeConstant(builder.getBoolAttr(true));
  }

  if (auto floatType = llvm::dyn_cast<FloatType>(type))
    return makeConstant(builder.getFloatAttr(floatType, 1.0));

  if (auto vectorType = llvm::dyn_cast<VectorType>(type)) {
    Type elemType = vectorType.getElementType();
    if (llvm::isa<FloatType>(elemType))
      return makeConstant(DenseFPElementsAttr::get(
          vectorType, FloatAttr::get(elemType, 1.0).getValue()));
    if (llvm::isa<IntegerType>(elemType))
      return makeConstant(DenseElementsAttr::get(
          vectorType, IntegerAttr::get(elemType, 1).getValue()));
  }

  llvm_unreachable("unimplemented types for ConstantOp::getOne()");
}

/// Suggests a readable SSA name for the constant's result:
///   - i1 integer constants are named `true` / `false`;
///   - other integer constants are named `cst<value>_<type>`, with the value
///     printed according to the integer type's signedness;
///   - float-typed constants are named `cst_<type>`;
///   - vector-typed constants get `_vec_<N>`, followed by `x<elemType>` when
///     the element type is an integer or float.
void mlir::spirv::ConstantOp::getAsmResultNames(
    llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
  Type type = getType();

  SmallString<32> specialNameBuffer;
  llvm::raw_svector_ostream specialName(specialNameBuffer);
  specialName << "cst";

  IntegerType intTy = llvm::dyn_cast<IntegerType>(type);

  // Only fold the literal into the name when the result type is a scalar
  // integer. An IntegerAttr value does not by itself guarantee that, so
  // `intTy` is checked explicitly rather than relying on prior verification.
  if (auto intCst = llvm::dyn_cast<IntegerAttr>(getValue())) {
    if (intTy) {
      if (intTy.getWidth() == 1) {
        // getBoolValue() works for any signedness, unlike getInt().
        return setNameFn(getResult(),
                         intCst.getValue().getBoolValue() ? "true" : "false");
      }

      if (intTy.isSignless()) {
        specialName << intCst.getInt();
      } else if (intTy.isUnsigned()) {
        specialName << intCst.getUInt();
      } else {
        specialName << intCst.getSInt();
      }
    }
  }

  if (intTy || llvm::isa<FloatType>(type)) {
    specialName << '_' << type;
  }

  if (auto vecType = llvm::dyn_cast<VectorType>(type)) {
    specialName << "_vec_";
    specialName << vecType.getDimSize(0);

    Type elementType = vecType.getElementType();

    if (llvm::isa<IntegerType>(elementType) ||
        llvm::isa<FloatType>(elementType)) {
      specialName << "x" << elementType;
    }
  }

  setNameFn(getResult(), specialName.str());
}

/// Names the result `<variable>_addr` after the referenced global variable.
void mlir::spirv::AddressOfOp::getAsmResultNames(
    llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
  SmallString<32> nameStorage;
  llvm::raw_svector_ostream nameStream(nameStorage);
  nameStream << getVariable() << "_addr";
  setNameFn(getResult(), nameStream.str());
}

//===----------------------------------------------------------------------===//
// spirv.ControlBarrierOp
//===----------------------------------------------------------------------===//

/// Checks the barrier's memory-semantics mask with verifyMemorySemantics.
LogicalResult spirv::ControlBarrierOp::verify() {
  auto semantics = getMemorySemantics();
  return verifyMemorySemantics(getOperation(), semantics);
}

//===----------------------------------------------------------------------===//
// spirv.EntryPoint
//===----------------------------------------------------------------------===//

/// Builds a spirv.EntryPoint for `function` with the given execution model and
/// interface variable references.
void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
                                spirv::ExecutionModel executionModel,
                                spirv::FuncOp function,
                                ArrayRef<Attribute> interfaceVars) {
  auto modelAttr =
      spirv::ExecutionModelAttr::get(builder.getContext(), executionModel);
  auto fnRef = SymbolRefAttr::get(function);
  auto interfaceAttr = builder.getArrayAttr(interfaceVars);
  build(builder, state, modelAttr, fnRef, interfaceAttr);
}

/// Parses `"ExecutionModel" @fn (, @var)*`. The optional trailing symbol
/// references are collected into the interface array attribute. That
/// attribute is always added, and is empty when no variables are listed.
ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  spirv::ExecutionModel execModel;
  SmallVector<Attribute, 4> interfaceVars;

  FlatSymbolRefAttr fn;
  if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
      parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
    return failure();
  }

  if (!parser.parseOptionalComma()) {
    // Parse the interface variables.
    if (parser.parseCommaSeparatedList([&]() -> ParseResult {
          // Each variable is parsed into a throwaway attribute list and only
          // the symbol reference is kept, so the attribute name is irrelevant.
          FlatSymbolRefAttr var;
          NamedAttrList attrs;
          if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
            return failure();
          interfaceVars.push_back(var);
          return success();
        }))
      return failure();
  }
  result.addAttribute(spirv::EntryPointOp::getInterfaceAttrName(result.name),
                      parser.getBuilder().getArrayAttr(interfaceVars));
  return success();
}

/// Prints `"ExecutionModel" @fn`, followed by `, @var, ...` when the interface
/// list is non-empty.
void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
  printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
  printer.printSymbolName(getFn());

  auto vars = getInterface().getValue();
  if (vars.empty())
    return;
  printer << ", ";
  llvm::interleaveComma(vars, printer);
}

/// Nothing is checked locally. The `fn` and interface symbol references are
/// validated as part of spirv::ModuleOp verification.
LogicalResult spirv::EntryPointOp::verify() {
  return LogicalResult::success();
}

//===----------------------------------------------------------------------===//
// spirv.ExecutionMode
//===----------------------------------------------------------------------===//

/// Convenience builder: wraps the function, execution mode and integer
/// parameters into attributes and forwards to the attribute-based builder.
void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
                                   spirv::FuncOp function,
                                   spirv::ExecutionMode executionMode,
                                   ArrayRef<int32_t> params) {
  MLIRContext *ctx = builder.getContext();
  auto fnRef = SymbolRefAttr::get(function);
  auto modeAttr = spirv::ExecutionModeAttr::get(ctx, executionMode);
  auto paramsAttr = builder.getI32ArrayAttr(params);
  build(builder, state, fnRef, modeAttr, paramsAttr);
}

/// Parses the custom form of spirv.ExecutionMode:
///   @fn "<ExecutionMode>" (, <i32 literal>)*
/// The trailing integer literals are collected into an i32 array attribute
/// stored under the op's values attribute name.
ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
                                          OperationState &result) {
  spirv::ExecutionMode execMode;
  Attribute fn;
  if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
      parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
    return failure();
  }

  SmallVector<int32_t, 4> values;
  Type i32Type = parser.getBuilder().getIntegerType(32);
  while (!parser.parseOptionalComma()) {
    // The per-value attribute list is a throwaway; only the integer is kept.
    NamedAttrList attr;
    // Parse directly as IntegerAttr. A non-integer literal is then reported
    // as a parse error ("invalid kind of attribute specified") instead of
    // asserting in an unchecked cast.
    IntegerAttr value;
    if (parser.parseAttribute(value, i32Type, "value", attr)) {
      return failure();
    }
    values.push_back(value.getInt());
  }
  StringRef valuesAttrName =
      spirv::ExecutionModeOp::getValuesAttrName(result.name);
  result.addAttribute(valuesAttrName,
                      parser.getBuilder().getI32ArrayAttr(values));
  return success();
}

/// Prints the custom form: @fn "<ExecutionMode>" (, <int>)*
void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
  printer << " ";
  printer.printSymbolName(getFn());
  printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";

  auto modeValues = getValues();
  if (!modeValues.empty()) {
    printer << ", ";
    // Each entry is an IntegerAttr; print just its integer value.
    llvm::interleaveComma(modeValues, printer, [&](Attribute entry) {
      printer << llvm::cast<IntegerAttr>(entry).getInt();
    });
  }
}

//===----------------------------------------------------------------------===//
// spirv.func
//===----------------------------------------------------------------------===//

/// Parses the custom form of spirv.func:
///   @name(<args>) -> <results> "<FunctionControl>" [attributes {...}] [body]
ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();

  // Symbol name.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // Signature; variadic argument lists are rejected.
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<Type> resultTypes;
  SmallVector<DictionaryAttr> resultAttrs;
  bool isVariadic = false;
  if (function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  // Record the function type assembled from argument and result types.
  SmallVector<Type> argTypes;
  argTypes.reserve(entryArgs.size());
  for (const auto &entryArg : entryArgs)
    argTypes.push_back(entryArg.type);
  FunctionType fnType = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(fnType));

  // Function control enum, spelled as a quoted string; failure aborts.
  spirv::FunctionControl fnControl;
  if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
    return failure();

  // Extra attributes, if an `attributes {...}` dictionary is present.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // Attach per-argument and per-result attributes.
  assert(resultAttrs.size() == resultTypes.size());
  function_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  // The body is optional; an absent body is not an error.
  Region *body = result.addRegion();
  OptionalParseResult bodyResult =
      parser.parseOptionalRegion(*body, entryArgs);
  if (bodyResult.has_value() && failed(*bodyResult))
    return failure();
  return success();
}

/// Prints the custom form: @name(<args>) -> <results> "<FunctionControl>",
/// then any remaining attributes and, unless external, the body region.
void spirv::FuncOp::print(OpAsmPrinter &printer) {
  FunctionType fnType = getFunctionType();

  printer << " ";
  printer.printSymbolName(getSymName());
  function_interface_impl::printFunctionSignature(
      printer, *this, fnType.getInputs(),
      /*isVariadic=*/false, fnType.getResults());
  printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
          << "\"";

  // Attributes already spelled out in custom syntax are elided here.
  function_interface_impl::printFunctionAttributes(
      printer, *this,
      {spirv::attributeName<spirv::FunctionControl>(),
       getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
       getFunctionControlAttrName()});

  // An empty body region means an external function: nothing more to print.
  Region &body = getBody();
  if (body.empty())
    return;
  printer << ' ';
  printer.printRegion(body, /*printEntryBlockArgs=*/false,
                      /*printBlockTerminators=*/true);
}

/// Verifies the function type. Functions have at most one result. Pointer
/// parameters that reference the PhysicalStorageBuffer storage class (directly,
/// through a pointee pointer, or through an array of pointers) must carry the
/// decoration required by SPV_KHR_physical_storage_buffer.
LogicalResult spirv::FuncOp::verifyType() {
  FunctionType fnType = getFunctionType();
  if (fnType.getNumResults() > 1)
    return emitOpError("cannot have more than one result");

  // Argument attributes are queried via FunctionOpInterface's indexed
  // getArgAttrs(unsigned). Cast once rather than on every lambda call.
  auto func = llvm::cast<FunctionOpInterface>(getOperation());

  // Returns true if argument `argIndex` has an attribute named
  // spirv::DecorationAttr::name whose value is `decoration`. Attribute
  // dictionaries have unique keys, so at most one such attribute exists.
  auto hasDecorationAttr = [&](spirv::Decoration decoration,
                               unsigned argIndex) {
    for (auto argAttr : func.getArgAttrs(argIndex)) {
      if (argAttr.getName() != spirv::DecorationAttr::name)
        continue;
      if (auto decAttr = dyn_cast<spirv::DecorationAttr>(argAttr.getValue()))
        return decAttr.getValue() == decoration;
    }
    return false;
  };

  for (unsigned i = 0, e = this->getNumArguments(); i != e; ++i) {
    Type param = fnType.getInputs()[i];
    auto inputPtrType = dyn_cast<spirv::PointerType>(param);
    if (!inputPtrType)
      continue;

    auto pointeePtrType =
        dyn_cast<spirv::PointerType>(inputPtrType.getPointeeType());
    if (pointeePtrType) {
      // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
      // > If an OpFunctionParameter is a pointer (or contains a pointer)
      // > and the type it points to is a pointer in the PhysicalStorageBuffer
      // > storage class, the function parameter must be decorated with exactly
      // > one of AliasedPointer or RestrictPointer.
      if (pointeePtrType.getStorageClass() !=
          spirv::StorageClass::PhysicalStorageBuffer)
        continue;

      bool hasAliasedPtr =
          hasDecorationAttr(spirv::Decoration::AliasedPointer, i);
      bool hasRestrictPtr =
          hasDecorationAttr(spirv::Decoration::RestrictPointer, i);
      if (!hasAliasedPtr && !hasRestrictPtr)
        return emitOpError()
               << "with a pointer points to a physical buffer pointer must "
                  "be decorated either 'AliasedPointer' or 'RestrictPointer'";
      continue;
    }
    // SPIR-V spec, from SPV_KHR_physical_storage_buffer:
    // > If an OpFunctionParameter is a pointer (or contains a pointer) in
    // > the PhysicalStorageBuffer storage class, the function parameter must
    // > be decorated with exactly one of Aliased or Restrict.
    if (auto pointeeArrayType =
            dyn_cast<spirv::ArrayType>(inputPtrType.getPointeeType())) {
      // Pointer to an array: inspect the array's element pointer type.
      pointeePtrType =
          dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
    } else {
      // Otherwise the parameter pointer itself is checked.
      pointeePtrType = inputPtrType;
    }

    if (!pointeePtrType || pointeePtrType.getStorageClass() !=
                               spirv::StorageClass::PhysicalStorageBuffer)
      continue;

    bool hasAliased = hasDecorationAttr(spirv::Decoration::Aliased, i);
    bool hasRestrict = hasDecorationAttr(spirv::Decoration::Restrict, i);
    if (!hasAliased && !hasRestrict)
      return emitOpError() << "with physical buffer pointer must be decorated "
                              "either 'Aliased' or 'Restrict'";
  }

  return success();
}

/// Verifies that every spirv.Return / spirv.ReturnValue nested in the body
/// agrees with the function's declared result types.
LogicalResult spirv::FuncOp::verifyBody() {
  FunctionType fnType = getFunctionType();

  auto walkResult = walk([fnType](Operation *op) -> WalkResult {
    if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
      // A bare return is only valid in functions without results.
      if (fnType.getNumResults() != 0)
        return retOp.emitOpError("cannot be used in functions returning value");
    } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
      if (fnType.getNumResults() != 1)
        return retOp.emitOpError(
                   "returns 1 value but enclosing function requires ")
               << fnType.getNumResults() << " results";

      Type retOperandType = retOp.getValue().getType();
      Type fnResultType = fnType.getResult(0);
      // No leading space: emitOpError already separates the op prefix.
      if (retOperandType != fnResultType)
        return retOp.emitOpError("return value's type (")
               << retOperandType << ") mismatch with function's result type ("
               << fnResultType << ")";
    }
    return WalkResult::advance();
  });

  // TODO: verify other bits like linkage type.

  return failure(walkResult.wasInterrupted());
}

/// Builds a spirv.func with the given symbol name, type and function control.
/// `attrs` are appended after those, and the body region is created empty.
void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
                          StringRef name, FunctionType type,
                          spirv::FunctionControl control,
                          ArrayRef<NamedAttribute> attrs) {
  StringAttr symName = builder.getStringAttr(name);
  auto controlAttr = builder.getAttr<spirv::FunctionControlAttr>(control);

  state.addAttribute(SymbolTable::getSymbolAttrName(), symName);
  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
  state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
                     controlAttr);
  state.attributes.append(attrs.begin(), attrs.end());

  state.addRegion();
}

//===----------------------------------------------------------------------===//
// spirv.GLFClampOp
//===----------------------------------------------------------------------===//

/// Parsing delegates to the shared helper for single-result ops whose
/// operands all have the same type.
ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  return parseOneResultSameOperandTypeOp(parser, result);
}

/// Printing mirrors parse() through the shared single-result printer.
void spirv::GLFClampOp::print(OpAsmPrinter &p) {
  printOneResultOp(*this, p);
}

//===----------------------------------------------------------------------===//
// spirv.GLUClampOp
//===----------------------------------------------------------------------===//

/// Parsing delegates to the shared helper for single-result ops whose
/// operands all have the same type.
ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  return parseOneResultSameOperandTypeOp(parser, result);
}

/// Printing mirrors parse() through the shared single-result printer.
void spirv::GLUClampOp::print(OpAsmPrinter &p) {
  printOneResultOp(*this, p);
}

//===----------------------------------------------------------------------===//
// spirv.GLSClampOp
//===----------------------------------------------------------------------===//

/// Parsing delegates to the shared helper for single-result ops whose
/// operands all have the same type.
ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
                                     OperationState &result) {
  return parseOneResultSameOperandTypeOp(parser, result);
}

/// Printing mirrors parse() through the shared single-result printer.
void spirv::GLSClampOp::print(OpAsmPrinter &p) {
  printOneResultOp(*this, p);
}

//===----------------------------------------------------------------------===//
// spirv.GLFmaOp
//===----------------------------------------------------------------------===//

/// Parsing delegates to the shared helper for single-result ops whose
/// operands all have the same type.
ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser,
                                  OperationState &result) {
  return parseOneResultSameOperandTypeOp(parser, result);
}

/// Printing mirrors parse() through the shared single-result printer.
void spirv::GLFmaOp::print(OpAsmPrinter &p) {
  printOneResultOp(*this, p);
}

//===----------------------------------------------------------------------===//
// spirv.GlobalVariable
//===----------------------------------------------------------------------===//

/// Builds a global variable decorated with a DescriptorSet and a Binding,
/// each stored as an i32 attribute under its decoration attribute name.
void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
                                    Type type, StringRef name,
                                    unsigned descriptorSet, unsigned binding) {
  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));

  auto addI32Decoration = [&](spirv::Decoration decoration, unsigned value) {
    state.addAttribute(spirv::SPIRVDialect::getAttributeName(decoration),
                       builder.getI32IntegerAttr(value));
  };
  addI32Decoration(spirv::Decoration::DescriptorSet, descriptorSet);
  addI32Decoration(spirv::Decoration::Binding, binding);
}

/// Builds a global variable decorated as a built-in. The built-in is recorded
/// as its string spelling under the BuiltIn decoration attribute name.
void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
                                    Type type, StringRef name,
                                    spirv::BuiltIn builtin) {
  build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));

  StringAttr builtinSpelling =
      builder.getStringAttr(spirv::stringifyBuiltIn(builtin));
  state.addAttribute(
      spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
      builtinSpelling);
}

/// Parses the custom form of spirv.GlobalVariable:
///   @name [initializer(@sym)] <decorations> : <pointer type>
ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
                                           OperationState &result) {
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // Optional initializer clause; its keyword is the attribute name itself.
  StringRef initializerAttrName =
      spirv::GlobalVariableOp::getInitializerAttrName(result.name);
  if (succeeded(parser.parseOptionalKeyword(initializerAttrName))) {
    FlatSymbolRefAttr initSymbol;
    if (parser.parseLParen() ||
        parser.parseAttribute(initSymbol, Type(), initializerAttrName,
                              result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  if (parseVariableDecorations(parser, result))
    return failure();

  // The trailing type must be a SPIR-V pointer type.
  Type varType;
  auto typeLoc = parser.getCurrentLocation();
  if (parser.parseColonType(varType))
    return failure();
  if (!llvm::isa<spirv::PointerType>(varType))
    return parser.emitError(typeLoc, "expected spirv.ptr type");

  StringRef typeAttrName =
      spirv::GlobalVariableOp::getTypeAttrName(result.name);
  result.addAttribute(typeAttrName, TypeAttr::get(varType));
  return success();
}

/// Prints the custom form: @name [initializer(@sym)] <decorations> : <type>.
/// Attributes rendered in custom syntax are elided from the decoration dict.
void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
  SmallVector<StringRef, 4> elidedAttrs;
  elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());

  printer << ' ';
  printer.printSymbolName(getSymName());
  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());

  if (auto initializer = getInitializer()) {
    StringRef initAttrName = getInitializerAttrName();
    printer << " " << initAttrName << '(';
    printer.printSymbolName(*initializer);
    printer << ')';
    elidedAttrs.push_back(initAttrName);
  }

  StringRef typeName = getTypeAttrName();
  elidedAttrs.push_back(typeName);
  spirv::printVariableDecorations(*this, printer, elidedAttrs);
  printer << " : " << getType();
}

/// Verifies the variable type, its storage class and, if present, that the
/// initializer symbol names a supported op.
LogicalResult spirv::GlobalVariableOp::verify() {
  if (!llvm::isa<spirv::PointerType>(getType()))
    return emitOpError("result must be of a !spv.ptr type");

  // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
  // object. It cannot be Generic. It must be the same as the Storage Class
  // operand of the Result Type."
  // Also, Function storage class is reserved by spirv.Variable.
  auto storageClass = this->storageClass();
  switch (storageClass) {
  case spirv::StorageClass::Generic:
  case spirv::StorageClass::Function:
    return emitOpError("storage class cannot be '")
           << stringifyStorageClass(storageClass) << "'";
  default:
    break;
  }

  auto init = (*this)->getAttrOfType<FlatSymbolRefAttr>(
      this->getInitializerAttrName());
  if (!init)
    return success();

  // TODO: Currently only variable initialization with specialization
  // constants and other variables is supported. They could be normal
  // constants in the module scope as well.
  Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
      (*this)->getParentOp(), init.getAttr());
  bool isSupportedInit =
      initOp && isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
                    spirv::SpecConstantCompositeOp>(initOp);
  if (!isSupportedInit)
    return emitOpError("initializer must be result of a "
                       "spirv.SpecConstant or spirv.GlobalVariable or "
                       "spirv.SpecConstantCompositeOp op");
  return success();
}

//===----------------------------------------------------------------------===//
// spirv.INTEL.SubgroupBlockRead
//===----------------------------------------------------------------------===//

/// Parses: "<StorageClass>" %ptr : <result type>. The pointer operand type is
/// derived from the storage class and the result type; for a vector result
/// the pointee is the vector's element type.
ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
                                                   OperationState &result) {
  spirv::StorageClass storageClass;
  OpAsmParser::UnresolvedOperand ptrInfo;
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
      parser.parseColon() || parser.parseType(elementType))
    return failure();

  Type pointeeType = elementType;
  if (auto vecType = llvm::dyn_cast<VectorType>(elementType))
    pointeeType = vecType.getElementType();
  auto ptrType = spirv::PointerType::get(pointeeType, storageClass);

  if (parser.resolveOperand(ptrInfo, ptrType, result.operands))
    return failure();

  result.addTypes(elementType);
  return success();
}

/// Prints: "<StorageClass>" %ptr : <result type>
/// The storage class comes from the pointer operand's type. It must be
/// printed because parse() requires it as the leading quoted string;
/// omitting it produced IR that could not be parsed back.
void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
  auto ptrType = llvm::cast<spirv::PointerType>(getPtr().getType());
  printer << " \"" << spirv::stringifyStorageClass(ptrType.getStorageClass())
          << "\" " << getPtr() << " : " << getType();
}

/// Delegates the pointer/value type consistency check to the helper shared
/// with spirv.INTEL.SubgroupBlockWrite.
LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
  if (succeeded(
          verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
    return success();
  return failure();
}

//===----------------------------------------------------------------------===//
// spirv.INTEL.SubgroupBlockWrite
//===----------------------------------------------------------------------===//

/// Parses: "<StorageClass>" %ptr, %value : <value type>. The pointer operand
/// type is derived from the storage class and the value type; for a vector
/// value the pointee is the vector's element type.
ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
                                                    OperationState &result) {
  spirv::StorageClass storageClass;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
  auto loc = parser.getCurrentLocation();
  Type elementType;
  if (parseEnumStrAttr(storageClass, parser) ||
      parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
      parser.parseType(elementType))
    return failure();

  Type pointeeType = elementType;
  if (auto vecType = llvm::dyn_cast<VectorType>(elementType))
    pointeeType = vecType.getElementType();
  Type ptrType = spirv::PointerType::get(pointeeType, storageClass);

  if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
                             result.operands))
    return failure();
  return success();
}

/// Prints: "<StorageClass>" %ptr, %value : <value type>
/// The storage class comes from the pointer operand's type. It must be
/// printed because parse() requires it as the leading quoted string;
/// omitting it produced IR that could not be parsed back.
void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
  auto ptrType = llvm::cast<spirv::PointerType>(getPtr().getType());
  printer << " \"" << spirv::stringifyStorageClass(ptrType.getStorageClass())
          << "\" " << getPtr() << ", " << getValue() << " : "
          << getValue().getType();
}

/// Delegates the pointer/value type consistency check to the helper shared
/// with spirv.INTEL.SubgroupBlockRead.
LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
  if (succeeded(
          verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
    return success();
  return failure();
}

//===----------------------------------------------------------------------===//
// spirv.IAddCarryOp
//===----------------------------------------------------------------------===//

// Parsing, printing and verification all forward to the helpers shared by the
// arithmetic extended binary ops.
ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  return ::parseArithmeticExtendedBinaryOp(parser, result);
}

void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
  ::printArithmeticExtendedBinaryOp(*this, printer);
}

LogicalResult spirv::IAddCarryOp::verify() {
  return ::verifyArithmeticExtendedBinaryOp(*this);
}

//===----------------------------------------------------------------------===//
// spirv.ISubBorrowOp
//===----------------------------------------------------------------------===//

// Parsing, printing and verification all forward to the helpers shared by the
// arithmetic extended binary ops.
ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
                                       OperationState &result) {
  return ::parseArithmeticExtendedBinaryOp(parser, result);
}

void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
  ::printArithmeticExtendedBinaryOp(*this, printer);
}

LogicalResult spirv::ISubBorrowOp::verify() {
  return ::verifyArithmeticExtendedBinaryOp(*this);
}

//===----------------------------------------------------------------------===//
// spirv.SMulExtended
//===----------------------------------------------------------------------===//

// Parsing, printing and verification all forward to the helpers shared by the
// arithmetic extended binary ops.
ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
                                         OperationState &result) {
  return ::parseArithmeticExtendedBinaryOp(parser, result);
}

void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
  ::printArithmeticExtendedBinaryOp(*this, printer);
}

LogicalResult spirv::SMulExtendedOp::verify() {
  return ::verifyArithmeticExtendedBinaryOp(*this);
}

//===----------------------------------------------------------------------===//
// spirv.UMulExtended
//===----------------------------------------------------------------------===//

// Parsing, printing and verification all forward to the helpers shared by the
// arithmetic extended binary ops.
ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
                                         OperationState &result) {
  return ::parseArithmeticExtendedBinaryOp(parser, result);
}

void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
  ::printArithmeticExtendedBinaryOp(*this, printer);
}

LogicalResult spirv::UMulExtendedOp::verify() {
  return ::verifyArithmeticExtendedBinaryOp(*this);
}

//===----------------------------------------------------------------------===//
// spirv.MemoryBarrierOp
//===----------------------------------------------------------------------===//

/// Only the memory-semantics operand needs validation; the shared helper
/// performs it.
LogicalResult spirv::MemoryBarrierOp::verify() {
  auto semantics = getMemorySemantics();
  return verifyMemorySemantics(getOperation(), semantics);
}

//===----------------------------------------------------------------------===//
// spirv.module
//===----------------------------------------------------------------------===//

/// Builds a spirv.module with a single empty body block and, if `name` is
/// given, a symbol name. The builder's insertion point is restored afterwards.
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
                            std::optional<StringRef> name) {
  OpBuilder::InsertionGuard guard(builder);
  builder.createBlock(state.addRegion());
  if (!name)
    return;
  state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
                          builder.getStringAttr(*name));
}

/// Builds a spirv.module with the given addressing and memory models, a
/// single empty body block, and an optional VCE triple and symbol name.
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
                            spirv::AddressingModel addressingModel,
                            spirv::MemoryModel memoryModel,
                            std::optional<VerCapExtAttr> vceTriple,
                            std::optional<StringRef> name) {
  OpBuilder::InsertionGuard guard(builder);

  auto addressingAttr =
      builder.getAttr<spirv::AddressingModelAttr>(addressingModel);
  auto memoryAttr = builder.getAttr<spirv::MemoryModelAttr>(memoryModel);
  state.addAttribute("addressing_model", addressingAttr);
  state.addAttribute("memory_model", memoryAttr);

  builder.createBlock(state.addRegion());

  if (vceTriple)
    state.addAttribute(getVCETripleAttrName(), *vceTriple);
  if (name)
    state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
                       builder.getStringAttr(*name));
}

/// Parses the custom form of spirv.module:
///   [@name] <AddressingModel> <MemoryModel> [requires <vce>]
///   [attributes {...}] <region>
/// An empty parsed region is given one empty block.
ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  Region *body = result.addRegion();

  // The symbol name is optional; its absence is not an error.
  StringAttr nameAttr;
  (void)parser.parseOptionalSymbolName(
      nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);

  // Addressing and memory models, as bare enum keywords.
  spirv::AddressingModel addrModel;
  if (spirv::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
                                                              result))
    return failure();
  spirv::MemoryModel memoryModel;
  if (spirv::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
                                                          result))
    return failure();

  // Optional version/capability/extension triple.
  if (succeeded(parser.parseOptionalKeyword("requires"))) {
    spirv::VerCapExtAttr vceTriple;
    if (parser.parseAttribute(vceTriple,
                              spirv::ModuleOp::getVCETripleAttrName(),
                              result.attributes))
      return failure();
  }

  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();
  if (parser.parseRegion(*body, /*arguments=*/{}))
    return failure();

  if (body->empty())
    body->push_back(new Block());
  return success();
}

/// Prints the custom form:
///   [@name] <AddressingModel> <MemoryModel> [requires <vce>]
///   [attributes {...}] <region>
void spirv::ModuleOp::print(OpAsmPrinter &printer) {
  if (std::optional<StringRef> symName = getName()) {
    printer << ' ';
    printer.printSymbolName(*symName);
  }

  printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
          << spirv::stringifyMemoryModel(getMemoryModel());

  // Attributes rendered in custom syntax are elided from the dictionary.
  auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
  auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
  SmallVector<StringRef, 4> elidedAttrs;
  elidedAttrs.push_back(addressingModelAttrName);
  elidedAttrs.push_back(memoryModelAttrName);
  elidedAttrs.push_back(mlir::SymbolTable::getSymbolAttrName());

  std::optional<spirv::VerCapExtAttr> vce = getVceTriple();
  if (vce) {
    printer << " requires " << *vce;
    elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
  }

  printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
  printer << ' ';
  printer.printRegion(getRegion());
}

/// Verifies the module body:
///  - only ops from the SPIR-V dialect appear at module level;
///  - each spirv.EntryPoint references an existing spirv.func, its interface
///    entries are symbol references to spirv.GlobalVariable ops, and no
///    (function, execution model) pair is declared twice;
///  - each spirv.func is either defined or has 'Import' linkage, and the ops
///    directly inside its blocks belong to the SPIR-V dialect.
LogicalResult spirv::ModuleOp::verifyRegions() {
  Dialect *dialect = (*this)->getDialect();
  DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
      entryPoints;
  mlir::SymbolTable table(*this);

  for (auto &op : *getBody()) {
    if (op.getDialect() != dialect)
      return op.emitError("'spirv.module' can only contain spirv.* ops");

    // For EntryPoint op, check that the function and execution model is not
    // duplicated in EntryPointOps. Also verify that the interface specified
    // comes from globalVariables here to make this check cheaper.
    if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
      auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
      if (!funcOp) {
        return entryPointOp.emitError("function '")
               << entryPointOp.getFn() << "' not found in 'spirv.module'";
      }
      if (auto interface = entryPointOp.getInterface()) {
        for (Attribute varRef : interface) {
          auto varSymRef = llvm::dyn_cast<FlatSymbolRefAttr>(varRef);
          if (!varSymRef) {
            // Close the quote opened in the message.
            return entryPointOp.emitError(
                       "expected symbol reference for interface "
                       "specification instead of '")
                   << varRef << "'";
          }
          auto variableOp =
              table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
          if (!variableOp) {
            return entryPointOp.emitError("expected spirv.GlobalVariable "
                                          "symbol reference instead of'")
                   << varSymRef << "'";
          }
        }
      }

      // A single lookup both detects duplicates and records the new entry.
      auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
          funcOp, entryPointOp.getExecutionModel());
      if (!entryPoints.try_emplace(key, entryPointOp).second)
        return entryPointOp.emitError("duplicate of a previous EntryPointOp");
    } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
      // If the function is external and does not have 'Import'
      // linkage_attributes(LinkageAttributes), emit an error. 'Import'
      // LinkageAttributes is used to import external functions.
      auto linkageAttr = funcOp.getLinkageAttributes();
      auto hasImportLinkage =
          linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
                          spirv::LinkageType::Import);
      if (funcOp.isExternal() && !hasImportLinkage)
        return op.emitError(
            "'spirv.module' cannot contain external functions "
            "without 'Import' linkage_attributes (LinkageAttributes)");

      // TODO: move this check to spirv.func.
      // Only ops directly inside the function's blocks are checked here.
      for (auto &block : funcOp)
        for (auto &nestedOp : block) {
          if (nestedOp.getDialect() != dialect)
            return nestedOp.emitError(
                "functions in 'spirv.module' can only contain spirv.* ops");
        }
    }
  }

  return success();
}

//===----------------------------------------------------------------------===//
// spirv.mlir.referenceof
//===----------------------------------------------------------------------===//

/// Verifies that the referenced symbol is a spirv.SpecConstant or
/// spirv.SpecConstantComposite whose type matches this op's result type.
LogicalResult spirv::ReferenceOfOp::verify() {
  Operation *specConstSym = SymbolTable::lookupNearestSymbolFrom(
      (*this)->getParentOp(), getSpecConstAttr());

  Type constType;
  if (auto scalarOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym)) {
    constType = scalarOp.getDefaultValue().getType();
  } else if (auto compositeOp =
                 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(
                     specConstSym)) {
    constType = compositeOp.getType();
  } else {
    return emitOpError(
        "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
  }

  if (getReference().getType() == constType)
    return success();
  return emitOpError("result type mismatch with the referenced "
                     "specialization constant's type");
}

//===----------------------------------------------------------------------===//
// spirv.SpecConstant
//===----------------------------------------------------------------------===//

/// Parses the custom form of spirv.SpecConstant:
///   @name [<kSpecIdAttrName>(<int>)] = <default value>
ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
                                         OperationState &result) {
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // Optional spec id, introduced by the keyword that names its attribute.
  if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
    IntegerAttr specIdAttr;
    if (parser.parseLParen() ||
        parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
        parser.parseRParen())
      return failure();
  }

  // Mandatory `= <default value>`.
  StringRef defaultValueAttrName =
      spirv::SpecConstantOp::getDefaultValueAttrName(result.name);
  Attribute valueAttr;
  if (parser.parseEqual() ||
      parser.parseAttribute(valueAttr, defaultValueAttrName, result.attributes))
    return failure();
  return success();
}

/// Prints: @name [<kSpecIdAttrName>(<int>)] = <default value>
void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
  printer << ' ';
  printer.printSymbolName(getSymName());

  auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName);
  if (specID)
    printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';

  printer << " = " << getDefaultValue();
}

/// Verifies the optional spec id is non-negative and the default value is an
/// integer/bool or float scalar whose type is a valid SPIR-V type.
LogicalResult spirv::SpecConstantOp::verify() {
  auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName);
  if (specID && specID.getValue().isNegative())
    return emitOpError("SpecId cannot be negative");

  auto value = getDefaultValue();
  // BoolAttr derives from IntegerAttr, so bools pass this check as well.
  if (!llvm::isa<IntegerAttr, FloatAttr>(value))
    return emitOpError(
        "default value can only be a bool, integer, or float scalar");

  // Make sure bitwidth is allowed.
  if (!llvm::isa<spirv::SPIRVType>(value.getType()))
    return emitOpError("default value bitwidth disallowed");
  return success();
}

//===----------------------------------------------------------------------===//
// spirv.VectorShuffle
//===----------------------------------------------------------------------===//

/// Verifies that the number of component selectors matches the result vector
/// length. Each selector must index into the concatenation of the two source
/// vectors, or be 0xffffffff, which the SPIR-V spec defines as an undefined
/// result component.
LogicalResult spirv::VectorShuffleOp::verify() {
  VectorType resultType = llvm::cast<VectorType>(getType());

  size_t numResultElements = resultType.getNumElements();
  if (numResultElements != getComponents().size())
    return emitOpError("result type element count (")
           << numResultElements
           << ") mismatch with the number of component selectors ("
           << getComponents().size() << ")";

  size_t totalSrcElements =
      llvm::cast<VectorType>(getVector1().getType()).getNumElements() +
      llvm::cast<VectorType>(getVector2().getType()).getNumElements();

  constexpr uint64_t kUndefSelector = std::numeric_limits<uint32_t>::max();
  for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
    // Compare at 64 bits so an over-wide selector cannot be truncated into
    // the valid range.
    uint64_t index = selector.getZExtValue();
    if (index >= totalSrcElements && index != kUndefSelector)
      return emitOpError("component selector ")
             << index << " out of range: expected to be in [0, "
             << totalSrcElements << ") or 0xffffffff";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// spirv.MatrixTimesScalar
//===----------------------------------------------------------------------===//

LogicalResult spirv::MatrixTimesScalarOp::verify() {
  // Extract the component type from either supported matrix flavor.
  Type matrixType = getMatrix().getType();
  Type componentType;
  if (auto coopType = llvm::dyn_cast<spirv::CooperativeMatrixType>(matrixType))
    componentType = coopType.getElementType();
  else if (auto plainType = llvm::dyn_cast<spirv::MatrixType>(matrixType))
    componentType = plainType.getElementType();

  assert(componentType && "Unhandled type");

  // The scaling value must match the matrix component type exactly.
  if (componentType != getScalar().getType())
    return emitOpError("input matrix components' type and scaling value must "
                       "have the same type");

  return success();
}

//===----------------------------------------------------------------------===//
// spirv.Transpose
//===----------------------------------------------------------------------===//

LogicalResult spirv::TransposeOp::verify() {
  auto srcType = llvm::cast<spirv::MatrixType>(getMatrix().getType());
  auto dstType = llvm::cast<spirv::MatrixType>(getResult().getType());

  // Transposition swaps rows and columns...
  const bool rowsBecomeColumns =
      srcType.getNumRows() == dstType.getNumColumns();
  if (!rowsBecomeColumns)
    return emitError("input matrix rows count must be equal to "
                     "output matrix columns count");

  const bool columnsBecomeRows =
      srcType.getNumColumns() == dstType.getNumRows();
  if (!columnsBecomeRows)
    return emitError("input matrix columns count must be equal to "
                     "output matrix rows count");

  // ...but leaves the component type untouched.
  if (srcType.getElementType() != dstType.getElementType())
    return emitError("input and output matrices must have the same "
                     "component type");

  return success();
}

//===----------------------------------------------------------------------===//
// spirv.MatrixTimesMatrix
//===----------------------------------------------------------------------===//

LogicalResult spirv::MatrixTimesMatrixOp::verify() {
  auto lhsTy = llvm::cast<spirv::MatrixType>(getLeftmatrix().getType());
  auto rhsTy = llvm::cast<spirv::MatrixType>(getRightmatrix().getType());
  auto resTy = llvm::cast<spirv::MatrixType>(getResult().getType());

  // (R x K) * (K x C) -> (R x C): the shared inner dimension must agree.
  if (lhsTy.getNumColumns() != rhsTy.getNumRows())
    return emitError("left matrix columns' count must be equal to "
                     "the right matrix rows' count");

  // The result inherits its column count from the right operand.
  if (rhsTy.getNumColumns() != resTy.getNumColumns())
    return emitError(
        "right and result matrices must have equal columns' count");

  // Both operands must share the result's component type.
  if (rhsTy.getElementType() != resTy.getElementType())
    return emitError("right and result matrices' component type must"
                     " be the same");
  if (lhsTy.getElementType() != resTy.getElementType())
    return emitError("left and result matrices' component type"
                     " must be the same");

  // The result inherits its row count from the left operand.
  if (lhsTy.getNumRows() != resTy.getNumRows())
    return emitError("left and result matrices must have equal rows' count");

  return success();
}

//===----------------------------------------------------------------------===//
// spirv.SpecConstantComposite
//===----------------------------------------------------------------------===//

ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
                                                  OperationState &result) {
  // Syntax: @name ( @c0, @c1, ... ) : type
  StringAttr symName;
  if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
                             result.attributes) ||
      parser.parseLParen())
    return failure();

  SmallVector<Attribute, 4> refs;

  // Parses one flat symbol reference into `refs`. The attribute name and the
  // scratch attribute list are throwaways; only the parsed value is kept.
  auto parseOneRef = [&]() -> ParseResult {
    FlatSymbolRefAttr ref;
    NamedAttrList scratch;
    if (parser.parseAttribute(ref, Type(), "spec_const", scratch))
      return failure();
    refs.push_back(ref);
    return success();
  };

  // At least one constituent is required, followed by optional `, @cN`.
  do {
    if (parseOneRef())
      return failure();
  } while (succeeded(parser.parseOptionalComma()));

  if (parser.parseRParen())
    return failure();

  Builder &builder = parser.getBuilder();
  result.addAttribute(
      spirv::SpecConstantCompositeOp::getConstituentsAttrName(result.name),
      builder.getArrayAttr(refs));

  Type type;
  if (parser.parseColonType(type))
    return failure();

  result.addAttribute(
      spirv::SpecConstantCompositeOp::getTypeAttrName(result.name),
      TypeAttr::get(type));

  return success();
}

void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
  // Emits: @name (@c0, @c1, ...) : type
  printer << ' ';
  printer.printSymbolName(getSymName());
  printer << " (";
  // interleaveComma prints nothing for an empty range.
  llvm::interleaveComma(getConstituents().getValue(), printer);
  printer << ") : " << getType();
}

/// Verifies that the op's type is a supported composite type, and that each
/// constituent symbol resolves to a spirv.SpecConstant whose default value
/// type matches the corresponding composite element type.
LogicalResult spirv::SpecConstantCompositeOp::verify() {
  auto cType = llvm::dyn_cast<spirv::CompositeType>(getType());
  auto constituents = this->getConstituents().getValue();

  if (!cType)
    return emitError("result type must be a composite type, but provided ")
           << getType();

  // Matrix-like composites are rejected by this op.
  if (llvm::isa<spirv::CooperativeMatrixType, spirv::JointMatrixINTELType>(
          cType))
    return emitError("unsupported composite type  ") << cType;

  if (constituents.size() != cType.getNumElements())
    return emitError("has incorrect number of operands: expected ")
           << cType.getNumElements() << ", but provided "
           << constituents.size();

  for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
    auto constituent = llvm::cast<FlatSymbolRefAttr>(constituents[index]);

    // lookupNearestSymbolFrom returns null for an unresolved symbol, and the
    // symbol may name an op that is not a spirv.SpecConstant. Plain dyn_cast
    // asserts on null, and the result was previously dereferenced unchecked,
    // so diagnose both cases instead of crashing.
    auto constituentSpecConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(
        SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
                                             constituent.getAttr()));
    if (!constituentSpecConstOp)
      return emitError("has invalid constituent ")
             << constituent
             << ": expected a reference to a spirv.SpecConstant op";

    Type expectedType = cType.getElementType(index);
    Type providedType = constituentSpecConstOp.getDefaultValue().getType();
    if (providedType != expectedType)
      return emitError("has incorrect types of operands: expected ")
             << expectedType << ", but provided " << providedType;
  }

  return success();
}

//===----------------------------------------------------------------------===//
// spirv.SpecConstantOperation
//===----------------------------------------------------------------------===//

/// Parses `wraps <generic-op> attr-dict`. The wrapped op is placed in the
/// single body block, followed by a spirv.YieldOp of its result. That result's
/// type becomes this op's result type.
ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
                                                  OperationState &result) {
  Region *body = result.addRegion();

  if (parser.parseKeyword("wraps"))
    return failure();

  body->push_back(new Block);
  Block &block = body->back();
  SMLoc wrappedOpLoc = parser.getCurrentLocation();
  Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());

  if (!wrappedOp)
    return failure();

  // Exactly one result is yielded and forwarded as this op's result. Reject
  // other result counts here rather than indexing a non-existent result.
  if (wrappedOp->getNumResults() != 1)
    return parser.emitError(wrappedOpLoc,
                            "wrapped op must have exactly one result, but has ")
           << wrappedOp->getNumResults();

  OpBuilder builder(parser.getContext());
  builder.setInsertionPointToEnd(&block);
  builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
  result.location = wrappedOp->getLoc();

  result.addTypes(wrappedOp->getResult(0).getType());

  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

/// Prints `wraps <generic-op>` followed by any attributes on this op. The
/// parser accepts a trailing optional attr-dict, so the printer emits one to
/// keep attributes across a print/parse round-trip. Nothing extra is printed
/// when the op carries no attributes.
void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
  printer << " wraps ";
  printer.printGenericOp(&getBody().front().front());
  printer.printOptionalAttrDict((*this)->getAttrs());
}

/// Verifies the body: exactly one wrapped op plus the terminator. The wrapped
/// op must carry the UsableInSpecConstantOp trait. Each of its operands must
/// be produced by a constant-like op (spirv.Constant, spirv.mlir.referenceof,
/// or a nested spirv.SpecConstantOperation).
LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
  Block &block = getRegion().getBlocks().front();

  if (block.getOperations().size() != 2)
    return emitOpError("expected exactly 2 nested ops");

  Operation &enclosedOp = block.getOperations().front();

  if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
    return emitOpError("invalid enclosed op");

  // An operand may be a block argument, for which getDefiningOp() is null.
  // Plain isa<> asserts on null, so isa_and_nonnull is used; such operands
  // are rejected with the regular diagnostic.
  for (Value operand : enclosedOp.getOperands())
    if (!llvm::isa_and_nonnull<spirv::ConstantOp, spirv::ReferenceOfOp,
                               spirv::SpecConstantOperationOp>(
            operand.getDefiningOp()))
      return emitOpError(
          "invalid operand, must be defined by a constant operation");

  return success();
}

//===----------------------------------------------------------------------===//
// spirv.GL.FrexpStruct
//===----------------------------------------------------------------------===//

/// Verifies that the result is a two-member struct, with these members:
///   member 0: the same type as the operand (the significand);
///   member 1: a 32-bit integer scalar or vector (the exponent) with the same
///             component count as the operand.
LogicalResult spirv::GLFrexpStructOp::verify() {
  // Guard the cast: the result was previously dereferenced without checking
  // that it really is a struct.
  auto structTy = llvm::dyn_cast<spirv::StructType>(getResult().getType());

  if (!structTy || structTy.getNumElements() != 2)
    return emitError("result type must be a struct type with two memebers");

  Type significandTy = structTy.getElementType(0);
  Type exponentTy = structTy.getElementType(1);
  auto exponentVecTy = llvm::dyn_cast<VectorType>(exponentTy);
  auto exponentIntTy = llvm::dyn_cast<IntegerType>(exponentTy);

  Type operandTy = getOperand().getType();
  auto operandVecTy = llvm::dyn_cast<VectorType>(operandTy);
  auto operandFTy = llvm::dyn_cast<FloatType>(operandTy);

  if (significandTy != operandTy)
    return emitError("member zero of the resulting struct type must be the "
                     "same type as the operand");

  // The exponent component type is the vector element type for vectors, or
  // the type itself for scalars. Either way it must be a 32-bit integer.
  IntegerType exponentComponentTy =
      exponentVecTy ? llvm::dyn_cast<IntegerType>(exponentVecTy.getElementType())
                    : exponentIntTy;
  if (!exponentComponentTy || exponentComponentTy.getWidth() != 32)
    return emitError("member one of the resulting struct type "
                     "must be a scalar or vector of 32 bit integer type");

  // The operand and exponent must have the same number of components.
  if (operandVecTy && exponentVecTy &&
      (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
    return success();

  if (operandFTy && exponentIntTy)
    return success();

  return emitError("member one of the resulting struct type must have the same "
                   "number of components as the operand type");
}

//===----------------------------------------------------------------------===//
// spirv.GL.Ldexp
//===----------------------------------------------------------------------===//

LogicalResult spirv::GLLdexpOp::verify() {
  Type xTy = getX().getType();
  Type expTy = getExp().getType();

  // Either both operands are scalars (float significand, integer exponent)
  // or neither is.
  const bool xIsScalar = llvm::isa<FloatType>(xTy);
  const bool expIsScalar = llvm::isa<IntegerType>(expTy);
  if (xIsScalar != expIsScalar)
    return emitOpError("operands must both be scalars or vectors");

  // A scalar counts as a single element.
  auto elementCount = [](Type ty) -> unsigned {
    auto vecTy = llvm::dyn_cast<VectorType>(ty);
    return vecTy ? vecTy.getNumElements() : 1;
  };

  if (elementCount(xTy) != elementCount(expTy))
    return emitOpError("operands must have the same number of elements");

  return success();
}

//===----------------------------------------------------------------------===//
// spirv.ImageDrefGather
//===----------------------------------------------------------------------===//

LogicalResult spirv::ImageDrefGatherOp::verify() {
  auto resultVecTy = llvm::cast<VectorType>(getResult().getType());
  auto sampledTy =
      llvm::cast<spirv::SampledImageType>(getSampledimage().getType());
  auto imageTy = llvm::cast<spirv::ImageType>(sampledTy.getImageType());

  // A gather always returns four components.
  if (resultVecTy.getNumElements() != 4)
    return emitOpError("result type must be a vector of four components");

  // When the image declares a sampled type, it must match the result
  // components. A NoneType sampled type skips this check.
  Type imageComponentTy = imageTy.getElementType();
  const bool hasSampledType = !llvm::isa<NoneType>(imageComponentTy);
  if (hasSampledType && resultVecTy.getElementType() != imageComponentTy)
    return emitOpError(
        "the component type of result must be the same as sampled type of the "
        "underlying image type");

  switch (imageTy.getDim()) {
  case spirv::Dim::Dim2D:
  case spirv::Dim::Cube:
  case spirv::Dim::Rect:
    break;
  default:
    return emitOpError(
        "the Dim operand of the underlying image type must be 2D, Cube, or "
        "Rect");
  }

  if (imageTy.getSamplingInfo() != spirv::ImageSamplingInfo::SingleSampled)
    return emitOpError("the MS operand of the underlying image type must be 0");

  spirv::ImageOperandsAttr operandsAttr = getImageoperandsAttr();
  auto extraOperands = getOperandArguments();
  return verifyImageOperands(*this, operandsAttr, extraOperands);
}

//===----------------------------------------------------------------------===//
// spirv.ShiftLeftLogicalOp
//===----------------------------------------------------------------------===//

/// Verifies spirv.ShiftLeftLogical by delegating to `verifyShiftOp`, the
/// shift-op verifier shared with the other SPIR-V shift ops in this file.
/// That helper is defined outside this chunk.
LogicalResult spirv::ShiftLeftLogicalOp::verify() {
  return verifyShiftOp(*this);
}

//===----------------------------------------------------------------------===//
// spirv.ShiftRightArithmeticOp
//===----------------------------------------------------------------------===//

/// Verifies spirv.ShiftRightArithmetic by delegating to `verifyShiftOp`, the
/// shift-op verifier shared with the other SPIR-V shift ops in this file.
/// That helper is defined outside this chunk.
LogicalResult spirv::ShiftRightArithmeticOp::verify() {
  return verifyShiftOp(*this);
}

//===----------------------------------------------------------------------===//
// spirv.ShiftRightLogicalOp
//===----------------------------------------------------------------------===//

/// Verifies spirv.ShiftRightLogical by delegating to `verifyShiftOp`, the
/// shift-op verifier shared with the other SPIR-V shift ops in this file.
/// That helper is defined outside this chunk.
LogicalResult spirv::ShiftRightLogicalOp::verify() {
  return verifyShiftOp(*this);
}

//===----------------------------------------------------------------------===//
// spirv.ImageQuerySize
//===----------------------------------------------------------------------===//

LogicalResult spirv::ImageQuerySizeOp::verify() {
  auto imageType = llvm::cast<spirv::ImageType>(getImage().getType());
  const spirv::Dim dim = imageType.getDim();

  // One pass over Dim yields the base size-component count (before any array
  // layer). It also records whether the MS/Sampled restriction applies.
  unsigned expectedComponents = 0;
  bool needsSamplingCheck = false;
  switch (dim) {
  case spirv::Dim::Dim1D:
    expectedComponents = 1;
    needsSamplingCheck = true;
    break;
  case spirv::Dim::Buffer:
    expectedComponents = 1;
    break;
  case spirv::Dim::Dim2D:
  case spirv::Dim::Cube:
    expectedComponents = 2;
    needsSamplingCheck = true;
    break;
  case spirv::Dim::Rect:
    expectedComponents = 2;
    break;
  case spirv::Dim::Dim3D:
    expectedComponents = 3;
    needsSamplingCheck = true;
    break;
  default:
    return emitError("the Dim operand of the image type must "
                     "be 1D, 2D, 3D, Buffer, Cube, or Rect");
  }

  if (needsSamplingCheck) {
    const bool isMultiSampled = imageType.getSamplingInfo() ==
                                spirv::ImageSamplingInfo::MultiSampled;
    const spirv::ImageSamplerUseInfo samplerUse =
        imageType.getSamplerUseInfo();
    const bool samplerUseAllowed =
        samplerUse == spirv::ImageSamplerUseInfo::SamplerUnknown ||
        samplerUse == spirv::ImageSamplerUseInfo::NoSampler;
    if (!isMultiSampled && !samplerUseAllowed)
      return emitError(
          "if Dim is 1D, 2D, 3D, or Cube, "
          "it must also have either an MS of 1 or a Sampled of 0 or 2");
  }

  // Arrayed images report the layer count as one extra component.
  if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
    ++expectedComponents;

  auto resultVecTy = llvm::dyn_cast<VectorType>(getResult().getType());
  const unsigned actualComponents =
      resultVecTy ? resultVecTy.getNumElements() : 1;

  if (expectedComponents != actualComponents)
    return emitError("expected the result to have ")
           << expectedComponents << " component(s), but found "
           << actualComponents << " component(s)";

  return success();
}

//===----------------------------------------------------------------------===//
// spirv.VectorTimesScalarOp
//===----------------------------------------------------------------------===//

/// Verifies that the vector operand has exactly the result type, and that the
/// scalar operand's type equals the result's vector element type.
LogicalResult spirv::VectorTimesScalarOp::verify() {
  if (getVector().getType() != getType())
    return emitOpError("vector operand and result type mismatch");
  auto scalarType = llvm::cast<VectorType>(getType()).getElementType();
  // The diagnostic fires on a mismatch, so it must say "mismatch". The
  // previous text ("... type match") stated the opposite of the failure.
  if (getScalar().getType() != scalarType)
    return emitOpError("scalar operand and result element type mismatch");
  return success();
}