//===- UnifyAliasedResourcePass.cpp - Pass to Unify Aliased Resources -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that unifies access of multiple aliased resources
// into access of one single resource.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/Transforms/Passes.h"

#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <algorithm>
#include <iterator>

namespace mlir {
namespace spirv {
#define GEN_PASS_DEF_SPIRVUNIFYALIASEDRESOURCEPASS
#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
} // namespace spirv
} // namespace mlir

#define DEBUG_TYPE "spirv-unify-aliased-resource"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Identifies the slot a resource is bound to.
using Descriptor = std::pair<uint32_t, uint32_t>; // (set #, binding #)
/// Maps a descriptor to every spirv.GlobalVariable op bound to it.
using AliasedResourceMap =
    DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;

/// Walks `moduleOp` and groups every spirv.GlobalVariable that carries the
/// unit `aliased` attribute by its (descriptor set, binding) pair. Variables
/// missing either the descriptor set or the binding are skipped.
static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
  AliasedResourceMap resourcesByDescriptor;
  moduleOp->walk([&](spirv::GlobalVariableOp globalVar) {
    // Only variables explicitly marked as aliased are of interest.
    if (!globalVar->getAttrOfType<UnitAttr>("aliased"))
      return;

    std::optional<uint32_t> descriptorSet = globalVar.getDescriptorSet();
    std::optional<uint32_t> bindingNumber = globalVar.getBinding();
    if (!descriptorSet || !bindingNumber)
      return;

    Descriptor key(*descriptorSet, *bindingNumber);
    resourcesByDescriptor[key].push_back(globalVar);
  });
  return resourcesByDescriptor;
}

/// Extracts the element type from a runtime array resource type, i.e. a type
/// of the form `!spirv.ptr<!spirv.struct<!spirv.rtarray<...>>>` where the
/// struct has exactly one member. Any other `type` yields a null type.
static Type getRuntimeArrayElementType(Type type) {
  if (auto pointerTy = dyn_cast<spirv::PointerType>(type)) {
    auto wrapperTy = dyn_cast<spirv::StructType>(pointerTy.getPointeeType());
    // The struct must wrap exactly one member, and that member must be a
    // runtime array.
    if (wrapperTy && wrapperTy.getNumElements() == 1) {
      if (auto arrayTy =
              dyn_cast<spirv::RuntimeArrayType>(wrapperTy.getElementType(0)))
        return arrayTy.getElementType();
    }
  }
  return Type();
}

/// Given a list of resource element `types`, returns the index of the canonical
/// resource that all resources should be unified into. Returns std::nullopt if
/// unable to unify.
///
/// Every entry in `types` must be a scalar or vector type. Selection rules:
///  * If at least one element type is a vector, the vector with the smallest
///    total bitwidth is chosen. All other vectors' bitwidths must be multiples
///    of it, and every resource's scalar bitwidth must be a multiple of the
///    chosen vector's component bitwidth.
///  * Otherwise, the scalar with the smallest bitwidth is chosen, and all other
///    scalar bitwidths must be multiples of it.
/// Vectors with an odd number of elements, or whose byte size is unknown, make
/// the whole set non-unifiable. An empty `types` list yields std::nullopt.
static std::optional<int>
deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
  // Nothing to unify. This also keeps the min_element dereferences below valid.
  if (types.empty())
    return std::nullopt;

  // scalarNumBits: contains all resources' scalar types' bit counts.
  // vectorNumBits: only contains resources whose element types are vectors.
  // vectorIndices: each vector's original index in `types`.
  SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices;
  scalarNumBits.reserve(types.size());
  vectorNumBits.reserve(types.size());
  vectorIndices.reserve(types.size());

  for (const auto &indexedTypes : llvm::enumerate(types)) {
    spirv::SPIRVType type = indexedTypes.value();
    assert(type.isScalarOrVector());
    if (auto vectorType = dyn_cast<VectorType>(type)) {
      if (vectorType.getNumElements() % 2 != 0)
        return std::nullopt; // Odd-sized vector has special layout
                             // requirements.

      std::optional<int64_t> numBytes = type.getSizeInBytes();
      if (!numBytes)
        return std::nullopt;

      scalarNumBits.push_back(
          vectorType.getElementType().getIntOrFloatBitWidth());
      vectorNumBits.push_back(*numBytes * 8);
      vectorIndices.push_back(indexedTypes.index());
    } else {
      scalarNumBits.push_back(type.getIntOrFloatBitWidth());
    }
  }

  if (!vectorNumBits.empty()) {
    // Choose the *vector* with the smallest bitwidth as the canonical resource,
    // so that we can still keep vectorized load/store and avoid partial updates
    // to large vectors.
    auto *minVal = llvm::min_element(vectorNumBits);
    // Make sure that every other vector's bitwidth is a multiple of the
    // canonical resource's bitwidth. Without this, we cannot properly adjust
    // the index later.
    if (llvm::any_of(vectorNumBits,
                     [&](int bits) { return bits % *minVal != 0; }))
      return std::nullopt;

    // Require all scalar type bit counts to be a multiple of the chosen
    // vector's primitive type to avoid reading/writing subcomponents.
    int index = vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
    int baseNumBits = scalarNumBits[index];
    if (llvm::any_of(scalarNumBits,
                     [&](int bits) { return bits % baseNumBits != 0; }))
      return std::nullopt;

    return index;
  }

  // All element types are scalars. Then choose the smallest bitwidth as the
  // canonical resource to avoid subcomponent load/store.
  auto *minVal = llvm::min_element(scalarNumBits);
  if (llvm::any_of(scalarNumBits,
                   [minVal](int bits) { return bits % *minVal != 0; }))
    return std::nullopt;
  return static_cast<int>(std::distance(scalarNumBits.begin(), minVal));
}

/// Returns true if both `a` and `b` are integer or float scalar types with
/// identical bitwidths.
static bool areSameBitwidthScalarType(Type a, Type b) {
  if (!a.isIntOrFloat() || !b.isIntOrFloat())
    return false;
  return a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
}

//===----------------------------------------------------------------------===//
// Analysis
//===----------------------------------------------------------------------===//

namespace {
/// A class for analyzing aliased resources.
///
/// Resources are expected to be spirv.GlobalVariable ops that have a descriptor
/// set and binding number. Such resources are of the type
/// `!spirv.ptr<!spirv.struct<...>>` per Vulkan requirements.
///
/// Right now, we only support the case that there is a single runtime array
/// inside the struct.
class ResourceAliasAnalysis {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis)

  /// Runs the analysis on the given root operation, which must be a
  /// spirv.module op.
  explicit ResourceAliasAnalysis(Operation *);

  /// Returns true if the given `op` can be rewritten to use a canonical
  /// resource. Handles spirv.GlobalVariable, spirv.mlir.addressof,
  /// spirv.AccessChain, spirv.Load and spirv.Store ops; returns false for a
  /// null `op` and for any other kind of op.
  bool shouldUnify(Operation *op) const;

  /// Returns all descriptors and their corresponding aliased resources. Only
  /// descriptors whose resources were found to be unifiable are present.
  const AliasedResourceMap &getResourceMap() const { return resourceMap; }

  /// Returns the canonical resource for the given descriptor/variable, or a
  /// null op if none was recorded for it.
  spirv::GlobalVariableOp
  getCanonicalResource(const Descriptor &descriptor) const;
  spirv::GlobalVariableOp
  getCanonicalResource(spirv::GlobalVariableOp varOp) const;

  /// Returns the element type for the given variable, or a null type if the
  /// variable was not recorded by the analysis.
  spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;

private:
  /// Given the descriptor and aliased resources bound to it, analyze whether we
  /// can unify them and record if so. Nothing is recorded when they cannot be
  /// unified.
  void recordIfUnifiable(const Descriptor &descriptor,
                         ArrayRef<spirv::GlobalVariableOp> resources);

  /// Mapping from a descriptor to all aliased resources bound to it.
  AliasedResourceMap resourceMap;

  /// Mapping from a descriptor to the chosen canonical resource.
  DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;

  /// Mapping from an aliased resource to its descriptor.
  DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;

  /// Mapping from an aliased resource to its element (scalar/vector) type.
  DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
};
} // namespace

/// Builds the analysis for the SPIR-V module `root`: gathers the aliased
/// resources per descriptor, then records, for every descriptor whose
/// resources can be unified, which one serves as the canonical resource.
ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
  auto moduleOp = cast<spirv::ModuleOp>(root);

  // Bucket the aliased resources by descriptor, then examine each bucket
  // independently. Buckets that cannot be unified are simply not recorded.
  AliasedResourceMap resourcesByDescriptor = collectAliasedResources(moduleOp);
  for (const auto &[descriptor, resources] : resourcesByDescriptor)
    recordIfUnifiable(descriptor, resources);
}

/// Decides whether `op` touches a non-canonical aliased resource and therefore
/// needs rewriting. Returns false for a null `op` and for unsupported ops.
bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
  if (op == nullptr)
    return false;

  // Pointer-consuming ops need unification exactly when the op that produces
  // their pointer does.
  Value pointer;
  if (auto accessChain = dyn_cast<spirv::AccessChainOp>(op))
    pointer = accessChain.getBasePtr();
  else if (auto load = dyn_cast<spirv::LoadOp>(op))
    pointer = load.getPtr();
  else if (auto store = dyn_cast<spirv::StoreOp>(op))
    pointer = store.getPtr();
  if (pointer)
    return shouldUnify(pointer.getDefiningOp());

  // An address-of op defers to the global variable it references.
  if (auto addressOf = dyn_cast<spirv::AddressOfOp>(op)) {
    Operation *referenced = SymbolTable::lookupSymbolIn(
        addressOf->getParentOfType<spirv::ModuleOp>(),
        addressOf.getVariable());
    return shouldUnify(referenced);
  }

  // A variable needs rewriting if a canonical resource was recorded for it
  // and it is not that canonical resource itself.
  if (auto variable = dyn_cast<spirv::GlobalVariableOp>(op)) {
    spirv::GlobalVariableOp canonical = getCanonicalResource(variable);
    if (!canonical)
      return false;
    return variable != canonical;
  }

  return false;
}

/// Returns the canonical resource recorded for `descriptor`, or a null op if
/// the descriptor has no recorded canonical resource.
spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
    const Descriptor &descriptor) const {
  // DenseMap::lookup yields a default-constructed (null) op on a miss.
  return canonicalResourceMap.lookup(descriptor);
}

/// Returns the canonical resource for the descriptor that `varOp` is bound to,
/// or a null op if `varOp` was not recorded by the analysis.
spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
    spirv::GlobalVariableOp varOp) const {
  if (auto entry = descriptorMap.find(varOp); entry != descriptorMap.end())
    return getCanonicalResource(entry->second);
  return {};
}

/// Returns the element type recorded for `varOp`, or a null type if `varOp`
/// was not recorded by the analysis.
spirv::SPIRVType
ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
  // DenseMap::lookup yields a default-constructed (null) type on a miss.
  return elementTypeMap.lookup(varOp);
}

void ResourceAliasAnalysis::recordIfUnifiable(
    const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
  // Collect the element types for all resources in the current set.
  SmallVector<spirv::SPIRVType> elementTypes;
  for (spirv::GlobalVariableOp resource : resources) {
    Type elementType = getRuntimeArrayElementType(resource.getType());
    if (!elementType)
      return; // Unexpected resource variable type.

    auto type = cast<spirv::SPIRVType>(elementType);
    if (!type.isScalarOrVector())
      return; // Unexpected resource element type.

    elementTypes.push_back(type);
  }

  std::optional<int> index = deduceCanonicalResource(elementTypes);
  if (!index)
    return;

  // Update internal data structures for later use.
  resourceMap[descriptor].assign(resources.begin(), resources.end());
  canonicalResourceMap[descriptor] = resources[*index];
  for (const auto &resource : llvm::enumerate(resources)) {
    descriptorMap[resource.value()] = descriptor;
    elementTypeMap[resource.value()] = elementTypes[resource.index()];
  }
}

//===----------------------------------------------------------------------===//
// Patterns
//===----------------------------------------------------------------------===//

/// Common base for the conversion patterns in this file. It is an
/// OpConversionPattern on `OpTy` that additionally gives derived patterns
/// read-only access to the resource alias analysis.
template <typename OpTy>
class ConvertAliasResource : public OpConversionPattern<OpTy> {
public:
  /// Creates the pattern for `context` with the given `benefit`. `analysis` is
  /// stored by reference, so it must outlive the pattern.
  ConvertAliasResource(const ResourceAliasAnalysis &analysis,
                       MLIRContext *context, PatternBenefit benefit = 1)
      : OpConversionPattern<OpTy>(context, benefit), analysis(analysis) {}

protected:
  /// Alias analysis results; held by reference and not owned by the pattern.
  const ResourceAliasAnalysis &analysis;
};

/// Erases a spirv.GlobalVariable op. Its users are expected to be rewritten by
/// the other patterns to refer to the canonical resource instead.
struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::GlobalVariableOp varOp, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    // No replacement is needed: the variable has no results, so dropping it
    // is sufficient.
    rewriter.eraseOp(varOp);
    return success();
  }
};

/// Redirects a spirv.mlir.addressof op from an aliased resource to the
/// canonical resource recorded for it by the analysis.
struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    // Resolve the symbol to the variable currently being referenced.
    Operation *symbolOp = SymbolTable::lookupSymbolIn(
        addressOp->getParentOfType<spirv::ModuleOp>(),
        addressOp.getVariable());
    auto aliasedVar = cast<spirv::GlobalVariableOp>(symbolOp);

    // Take the address of the canonical variable instead.
    spirv::GlobalVariableOp canonicalVar =
        analysis.getCanonicalResource(aliasedVar);
    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, canonicalVar);
    return success();
  }
};

/// Rewrites a spirv.AccessChain on an aliased resource into one on the
/// canonical resource, adjusting the last index for the difference between the
/// source and canonical element types:
///  * same type, or scalars of the same bitwidth: indices are kept as is;
///  * scalar source, vector canonical: the last index is divided by the size
///    ratio and the remainder is appended as an extra index into the vector;
///  * scalar/scalar or vector/vector of different sizes: the last index is
///    multiplied by the size ratio.
/// Any other combination is a match failure.
struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto baseAddressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
    if (!baseAddressOp)
      return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");

    // Find the aliased variable being accessed and its canonical counterpart.
    Operation *symbolOp = SymbolTable::lookupSymbolIn(
        acOp->getParentOfType<spirv::ModuleOp>(), baseAddressOp.getVariable());
    auto aliasedVar = cast<spirv::GlobalVariableOp>(symbolOp);
    spirv::GlobalVariableOp canonicalVar =
        analysis.getCanonicalResource(aliasedVar);

    spirv::SPIRVType srcElemType = analysis.getElementType(aliasedVar);
    spirv::SPIRVType dstElemType = analysis.getElementType(canonicalVar);

    // Identical element size and layout: the indices carry over unchanged.
    const bool indicesUnchanged =
        srcElemType == dstElemType ||
        areSameBitwidthScalarType(srcElemType, dstElemType);
    if (indicesUnchanged) {
      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.getBasePtr(), adaptor.getIndices());
      return success();
    }

    // Classify the remaining supported combinations up front.
    const bool srcIsScalar = srcElemType.isIntOrFloat();
    const bool dstIsScalar = dstElemType.isIntOrFloat();
    const bool srcIsVector = isa<VectorType>(srcElemType);
    const bool dstIsVector = isa<VectorType>(dstElemType);
    const bool scalarIntoVector = srcIsScalar && dstIsVector;
    const bool sameKind =
        (srcIsScalar && dstIsScalar) || (srcIsVector && dstIsVector);
    if (!scalarIntoVector && !sameKind)
      return rewriter.notifyMatchFailure(
          acOp, "unsupported src/dst types for spirv.AccessChain");

    Location loc = acOp.getLoc();
    int srcNumBytes = *srcElemType.getSizeInBytes();
    int dstNumBytes = *dstElemType.getSizeInBytes();

    auto newIndices = llvm::to_vector<4>(acOp.getIndices());
    Value lastIndex = newIndices.back();
    Type indexType = lastIndex.getType();

    if (scalarIntoVector) {
      // The source indexes scalars while the canonical buffer holds vectors.
      // Split the last index into "which vector" (quotient) and "which
      // component inside that vector" (remainder).
      assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
      int ratio = dstNumBytes / srcNumBytes;
      Value ratioValue = rewriter.create<spirv::ConstantOp>(
          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));

      newIndices.back() = rewriter.create<spirv::SDivOp>(loc, indexType,
                                                         lastIndex, ratioValue);
      newIndices.push_back(rewriter.create<spirv::SModOp>(
          loc, indexType, lastIndex, ratioValue));
    } else {
      // The source element is a whole multiple of the canonical element, so
      // one source element spans `ratio` canonical elements. Scale the last
      // index accordingly.
      assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
      int ratio = srcNumBytes / dstNumBytes;
      Value ratioValue = rewriter.create<spirv::ConstantOp>(
          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));

      newIndices.back() = rewriter.create<spirv::IMulOp>(loc, indexType,
                                                         lastIndex, ratioValue);
    }

    rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
        acOp, adaptor.getBasePtr(), newIndices);
    return success();
  }
};

/// Rewrites a spirv.Load from an aliased resource into loads from the
/// canonical resource:
///  * same element type: a single load replaces the original;
///  * scalars of the same bitwidth: a single load followed by a bitcast;
///  * scalar/scalar or vector/vector where the source element is a multiple of
///    the canonical element: `ratio` consecutive canonical elements are loaded
///    and combined (via spirv.CompositeConstruct, plus bitcasts where needed)
///    into one source-typed value. Ratios above 4 are a match failure.
/// Any other combination is a match failure.
struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // "src" refers to the original (aliased) pointer; "dst" refers to the
    // already-converted pointer into the canonical resource.
    auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
    auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
    auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
    auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());

    Location loc = loadOp.getLoc();
    auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
    if (srcElemType == dstElemType) {
      rewriter.replaceOp(loadOp, newLoadOp->getResults());
      return success();
    }

    if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
      auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
                                                      newLoadOp.getValue());
      rewriter.replaceOp(loadOp, castOp->getResults());

      return success();
    }

    if ((srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()) ||
        (isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType))) {
      // The source and destination have scalar types of different bitwidths, or
      // vector types of different component counts. For such cases, we load
      // multiple smaller bitwidth values and construct a larger bitwidth one.

      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
      int ratio = srcNumBytes / dstNumBytes;
      if (ratio > 4)
        return rewriter.notifyMatchFailure(loadOp, "more than 4 components");

      SmallVector<Value> components;
      components.reserve(ratio);
      components.push_back(newLoadOp);

      auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
      if (!acOp)
        return rewriter.notifyMatchFailure(loadOp, "ptr not spirv.AccessChain");

      auto indices = llvm::to_vector<4>(acOp.getIndices());
      if (indices.empty())
        return rewriter.notifyMatchFailure(loadOp,
                                           "access chain has no indices");

      // Advance the last index in its own integer type rather than assuming
      // i32: the converted access chain keeps the original index type, and
      // spirv.IAdd needs matching operand and result types.
      Type indexType = indices.back().getType();
      Value oneValue = spirv::ConstantOp::getOne(indexType, loc, rewriter);
      for (int i = 1; i < ratio; ++i) {
        // Load all subsequent components belonging to this element.
        indices.back() = rewriter.create<spirv::IAddOp>(
            loc, indexType, indices.back(), oneValue);
        auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
            loc, acOp.getBasePtr(), indices);
        // Assuming little endian, this reads lower-ordered bits of the number
        // to lower-numbered components of the vector.
        components.push_back(
            rewriter.create<spirv::LoadOp>(loc, componentAcOp));
      }

      // Create a vector of the components and then cast back to the larger
      // bitwidth element type. For spirv.bitcast, the lower-numbered components
      // of the vector map to lower-ordered bits of the larger bitwidth element
      // type.

      Type vectorType = srcElemType;
      if (!isa<VectorType>(srcElemType))
        vectorType = VectorType::get({ratio}, dstElemType);

      // If both the source and destination are vector types, we need to make
      // sure the scalar type is the same for composite construction later.
      if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType))
        if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
          if (srcElemVecType.getElementType() !=
              dstElemVecType.getElementType()) {
            int64_t count =
                dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);

            // Make sure not to create 1-element vectors, which are illegal in
            // SPIR-V.
            Type castType = srcElemVecType.getElementType();
            if (count > 1)
              castType = VectorType::get({count}, castType);

            for (Value &c : components)
              c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
          }
        }
      Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
          loc, vectorType, components);

      // For a scalar source, the composite built above is a vector of the
      // smaller scalars; reinterpret it as the wider source scalar.
      if (!isa<VectorType>(srcElemType))
        vectorValue =
            rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
      rewriter.replaceOp(loadOp, vectorValue);
      return success();
    }

    return rewriter.notifyMatchFailure(
        loadOp, "unsupported src/dst types for spirv.Load");
  }
};

/// Rewrites a spirv.Store to an aliased resource into a store to the canonical
/// resource. Only scalar element types of identical bitwidth are supported; the
/// stored value is bitcast when the two scalar types differ. Everything else is
/// a match failure.
struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // `newPtr` is the already-converted pointer into the canonical resource.
    Value newPtr = adaptor.getPtr();
    Type srcElemType =
        cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
    Type dstElemType =
        cast<spirv::PointerType>(newPtr.getType()).getPointeeType();

    if (!(srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat()))
      return rewriter.notifyMatchFailure(storeOp, "not scalar type");
    if (!areSameBitwidthScalarType(srcElemType, dstElemType))
      return rewriter.notifyMatchFailure(storeOp, "different bitwidth");

    // Reinterpret the value in the canonical element type if necessary.
    Value valueToStore = adaptor.getValue();
    if (srcElemType != dstElemType)
      valueToStore = rewriter.create<spirv::BitcastOp>(
          storeOp.getLoc(), dstElemType, valueToStore);

    // Carry over the original op's attributes.
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, newPtr, valueToStore,
                                                storeOp->getAttrs());
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass
//===----------------------------------------------------------------------===//

namespace {
/// Pass that rewrites accesses to aliased resources so that they go through a
/// single canonical resource per descriptor.
class UnifyAliasedResourcePass final
    : public spirv::impl::SPIRVUnifyAliasedResourcePassBase<
          UnifyAliasedResourcePass> {
public:
  /// `getTargetEnv` may be empty; when set, it is queried in runOnOperation to
  /// decide whether the pass should do anything for the module's target.
  explicit UnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv)
      : getTargetEnvFn(std::move(getTargetEnv)) {}

  /// Runs the unification on the current spirv.module op.
  void runOnOperation() override;

private:
  /// Optional callback returning the target environment for a module.
  spirv::GetTargetEnvFn getTargetEnvFn;
};

void UnifyAliasedResourcePass::runOnOperation() {
  spirv::ModuleOp moduleOp = getOperation();
  MLIRContext *context = &getContext();

  if (getTargetEnvFn) {
    // This pass is only needed for targeting WebGPU, Metal, or layering
    // Vulkan on Metal via MoltenVK, where we need to translate SPIR-V into
    // WGSL or MSL. The translation has limitations.
    spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
    spirv::ClientAPI clientAPI = targetEnv.getClientAPI();
    bool isVulkanOnAppleDevices =
        clientAPI == spirv::ClientAPI::Vulkan &&
        targetEnv.getVendorID() == spirv::Vendor::Apple;
    if (clientAPI != spirv::ClientAPI::WebGPU &&
        clientAPI != spirv::ClientAPI::Metal && !isVulkanOnAppleDevices)
      return;
  }

  // Analyze aliased resources first.
  ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();

  ConversionTarget target(*context);
  target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
                               spirv::AccessChainOp, spirv::LoadOp,
                               spirv::StoreOp>(
      [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
  target.addLegalDialect<spirv::SPIRVDialect>();

  // Run patterns to rewrite usages of non-canonical resources.
  RewritePatternSet patterns(context);
  patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
               ConvertLoad, ConvertStore>(analysis, context);
  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
    return signalPassFailure();

  // Drop aliased attribute if we only have one single bound resource for a
  // descriptor. We need to re-collect the map here given in the above the
  // conversion is best effort; certain sets may not be converted.
  AliasedResourceMap resourceMap =
      collectAliasedResources(cast<spirv::ModuleOp>(moduleOp));
  for (const auto &dr : resourceMap) {
    const auto &resources = dr.second;
    if (resources.size() == 1)
      resources.front()->removeAttr("aliased");
  }
}
} // namespace

/// Creates the aliased-resource unification pass. `getTargetEnv` is forwarded
/// to the pass and may be empty.
std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
spirv::createUnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) {
  std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>> pass =
      std::make_unique<UnifyAliasedResourcePass>(std::move(getTargetEnv));
  return pass;
}