#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include <algorithm>
#include <iterator>
// Pull in the tablegen-generated pass definitions; this provides
// spirv::impl::SPIRVUnifyAliasedResourcePassBase, which the pass class in this
// file derives from.
namespace mlir {
namespace spirv {
#define GEN_PASS_DEF_SPIRVUNIFYALIASEDRESOURCEPASS
#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
} // namespace spirv
} // namespace mlir
#define DEBUG_TYPE "spirv-unify-aliased-resource"
using namespace mlir;
// A resource location: (descriptor set, binding).
using Descriptor = std::pair<uint32_t, uint32_t>;
// Groups every `aliased` global variable by the descriptor it is bound to.
using AliasedResourceMap =
DenseMap<Descriptor, SmallVector<spirv::GlobalVariableOp>>;
/// Walks `moduleOp` and groups every spirv.GlobalVariable carrying the
/// `aliased` unit attribute by its (descriptor set, binding) pair. Variables
/// lacking either a descriptor set or a binding are skipped.
static AliasedResourceMap collectAliasedResources(spirv::ModuleOp moduleOp) {
  AliasedResourceMap result;
  moduleOp->walk([&result](spirv::GlobalVariableOp globalVar) {
    if (!globalVar->getAttrOfType<UnitAttr>("aliased"))
      return;
    std::optional<uint32_t> descriptorSet = globalVar.getDescriptorSet();
    std::optional<uint32_t> bindingIndex = globalVar.getBinding();
    if (!descriptorSet || !bindingIndex)
      return;
    result[Descriptor{*descriptorSet, *bindingIndex}].push_back(globalVar);
  });
  return result;
}
/// If `type` is a SPIR-V pointer to a struct with exactly one member, and that
/// member is a runtime array, returns the runtime array's element type.
/// Returns a null Type for any other shape.
static Type getRuntimeArrayElementType(Type type) {
  if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
    auto structType = dyn_cast<spirv::StructType>(ptrType.getPointeeType());
    if (structType && structType.getNumElements() == 1) {
      if (auto rtArrayType =
              dyn_cast<spirv::RuntimeArrayType>(structType.getElementType(0)))
        return rtArrayType.getElementType();
    }
  }
  return Type();
}
/// Given the runtime-array element types of all resources aliased to one
/// descriptor, returns the index (into `types`) of the resource chosen as the
/// canonical one, or std::nullopt if none can be chosen.
///
/// Selection rules:
///  - Any vector with an odd element count, or without a known byte size,
///    rejects the whole group.
///  - If at least one element type is a vector, the vector with the smallest
///    total bitwidth is chosen. Every vector's total bitwidth must be a
///    multiple of it, and every resource's scalar (or vector element)
///    bitwidth must be a multiple of the chosen vector's element bitwidth.
///  - Otherwise, the scalar with the smallest bitwidth is chosen and all other
///    scalar bitwidths must be multiples of it.
///  - An empty `types` list yields std::nullopt.
static std::optional<int>
deduceCanonicalResource(ArrayRef<spirv::SPIRVType> types) {
  // Without this guard, both min_element calls below would return end() and
  // the function would report index 0 for a list that has no element 0.
  if (types.empty())
    return std::nullopt;

  // scalarNumBits: scalar (or vector element) bitwidth of every resource.
  // vectorNumBits: total bitwidth, recorded only for vector-typed resources.
  // vectorIndices: the index in `types` of each vectorNumBits entry.
  SmallVector<int> scalarNumBits, vectorNumBits, vectorIndices;
  scalarNumBits.reserve(types.size());
  vectorNumBits.reserve(types.size());
  vectorIndices.reserve(types.size());

  for (const auto &indexedTypes : llvm::enumerate(types)) {
    spirv::SPIRVType type = indexedTypes.value();
    assert(type.isScalarOrVector());
    if (auto vectorType = dyn_cast<VectorType>(type)) {
      if (vectorType.getNumElements() % 2 != 0)
        return std::nullopt;

      std::optional<int64_t> numBytes = type.getSizeInBytes();
      if (!numBytes)
        return std::nullopt;

      scalarNumBits.push_back(
          vectorType.getElementType().getIntOrFloatBitWidth());
      vectorNumBits.push_back(*numBytes * 8);
      vectorIndices.push_back(indexedTypes.index());
    } else {
      scalarNumBits.push_back(type.getIntOrFloatBitWidth());
    }
  }

  if (!vectorNumBits.empty()) {
    // Pick the narrowest vector; all other vectors must be whole multiples of
    // it so their indices can be rescaled.
    auto *minVal = llvm::min_element(vectorNumBits);
    if (llvm::any_of(vectorNumBits,
                     [&](int bits) { return bits % *minVal != 0; }))
      return std::nullopt;

    // Every scalar bitwidth must be a multiple of the chosen vector's
    // element bitwidth.
    int index = vectorIndices[std::distance(vectorNumBits.begin(), minVal)];
    int baseNumBits = scalarNumBits[index];
    if (llvm::any_of(scalarNumBits,
                     [&](int bits) { return bits % baseNumBits != 0; }))
      return std::nullopt;

    return index;
  }

  // All scalars: pick the narrowest; the others must be multiples of it.
  auto *minVal = llvm::min_element(scalarNumBits);
  if (llvm::any_of(scalarNumBits,
                   [&](int bits) { return bits % *minVal != 0; }))
    return std::nullopt;
  return std::distance(scalarNumBits.begin(), minVal);
}
/// Returns true only when both `a` and `b` are integer or floating-point
/// scalars with identical bitwidths.
static bool areSameBitwidthScalarType(Type a, Type b) {
  if (!a.isIntOrFloat() || !b.isIntOrFloat())
    return false;
  return a.getIntOrFloatBitWidth() == b.getIntOrFloatBitWidth();
}
namespace {
/// Analysis over a spirv.module that groups `aliased` global variables by
/// (descriptor set, binding). For each group whose runtime-array element types
/// admit a canonical choice, it records the canonical variable and each
/// member's element type, so that rewrite patterns can redirect accesses.
class ResourceAliasAnalysis {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ResourceAliasAnalysis)
/// Runs the analysis; the root operation is cast to spirv::ModuleOp.
explicit ResourceAliasAnalysis(Operation *);
/// Returns true if `op` refers (directly or through address-of / access chain
/// / load / store) to a non-canonical member of a unifiable group.
bool shouldUnify(Operation *op) const;
/// Returns every descriptor group that was recorded as unifiable.
const AliasedResourceMap &getResourceMap() const { return resourceMap; }
/// Returns the canonical variable for `descriptor`, or a null op if none.
spirv::GlobalVariableOp
getCanonicalResource(const Descriptor &descriptor) const;
/// Returns the canonical variable of the group containing `varOp`, or a null
/// op if `varOp` is not in a recorded group.
spirv::GlobalVariableOp
getCanonicalResource(spirv::GlobalVariableOp varOp) const;
/// Returns `varOp`'s runtime-array element type, or null if not recorded.
spirv::SPIRVType getElementType(spirv::GlobalVariableOp varOp) const;
private:
/// Records `resources` (all bound to `descriptor`) if a canonical resource
/// can be deduced for them; otherwise records nothing.
void recordIfUnifiable(const Descriptor &descriptor,
ArrayRef<spirv::GlobalVariableOp> resources);
// Descriptor -> all members of a unifiable group.
AliasedResourceMap resourceMap;
// Descriptor -> the group's canonical variable.
DenseMap<Descriptor, spirv::GlobalVariableOp> canonicalResourceMap;
// Member variable -> the descriptor of its group.
DenseMap<spirv::GlobalVariableOp, Descriptor> descriptorMap;
// Member variable -> its runtime-array element type.
DenseMap<spirv::GlobalVariableOp, spirv::SPIRVType> elementTypeMap;
};
} // namespace
/// Collects the aliased resources of `root` (cast to spirv::ModuleOp) and
/// records each descriptor group that can be unified.
ResourceAliasAnalysis::ResourceAliasAnalysis(Operation *root) {
  auto moduleOp = cast<spirv::ModuleOp>(root);
  for (const auto &[descriptor, resources] : collectAliasedResources(moduleOp))
    recordIfUnifiable(descriptor, resources);
}
/// Returns true if `op` refers to a non-canonical member of a unifiable group.
/// Handled ops are global variables themselves, address-of ops (resolved
/// through the enclosing module's symbol table), and access chains, loads and
/// stores (followed back through their pointer operand). Null or any other op
/// yields false.
bool ResourceAliasAnalysis::shouldUnify(Operation *op) const {
  if (!op)
    return false;

  if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) {
    spirv::GlobalVariableOp canonical = getCanonicalResource(globalVar);
    return canonical && canonical != globalVar;
  }

  if (auto addressOf = dyn_cast<spirv::AddressOfOp>(op)) {
    auto parentModule = addressOf->getParentOfType<spirv::ModuleOp>();
    return shouldUnify(
        SymbolTable::lookupSymbolIn(parentModule, addressOf.getVariable()));
  }

  // Pointer consumers: recurse into whatever produced the pointer.
  Value pointer;
  if (auto accessChain = dyn_cast<spirv::AccessChainOp>(op))
    pointer = accessChain.getBasePtr();
  else if (auto load = dyn_cast<spirv::LoadOp>(op))
    pointer = load.getPtr();
  else if (auto store = dyn_cast<spirv::StoreOp>(op))
    pointer = store.getPtr();
  else
    return false;
  return shouldUnify(pointer.getDefiningOp());
}
/// Returns the canonical variable recorded for `descriptor`; a
/// default-constructed (null) op when the descriptor was not recorded.
spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
    const Descriptor &descriptor) const {
  return canonicalResourceMap.lookup(descriptor);
}
/// Returns the canonical variable of the group `varOp` belongs to, or a null
/// op when `varOp` is not part of any recorded group.
spirv::GlobalVariableOp ResourceAliasAnalysis::getCanonicalResource(
    spirv::GlobalVariableOp varOp) const {
  // An explicit find is required: a default Descriptor could be a real key.
  if (auto it = descriptorMap.find(varOp); it != descriptorMap.end())
    return getCanonicalResource(it->second);
  return spirv::GlobalVariableOp();
}
/// Returns the runtime-array element type recorded for `varOp`; a null type
/// when `varOp` was not recorded.
spirv::SPIRVType
ResourceAliasAnalysis::getElementType(spirv::GlobalVariableOp varOp) const {
  return elementTypeMap.lookup(varOp);
}
void ResourceAliasAnalysis::recordIfUnifiable(
const Descriptor &descriptor, ArrayRef<spirv::GlobalVariableOp> resources) {
SmallVector<spirv::SPIRVType> elementTypes;
for (spirv::GlobalVariableOp resource : resources) {
Type elementType = getRuntimeArrayElementType(resource.getType());
if (!elementType)
return;
auto type = cast<spirv::SPIRVType>(elementType);
if (!type.isScalarOrVector())
return;
elementTypes.push_back(type);
}
std::optional<int> index = deduceCanonicalResource(elementTypes);
if (!index)
return;
resourceMap[descriptor].assign(resources.begin(), resources.end());
canonicalResourceMap[descriptor] = resources[*index];
for (const auto &resource : llvm::enumerate(resources)) {
descriptorMap[resource.value()] = descriptor;
elementTypeMap[resource.value()] = elementTypes[resource.index()];
}
}
/// Shared base of the rewrite patterns below: an OpConversionPattern for
/// `OpTy` that also carries the alias analysis the patterns consult.
template <typename OpTy>
class ConvertAliasResource : public OpConversionPattern<OpTy> {
public:
  ConvertAliasResource(const ResourceAliasAnalysis &aliasAnalysis,
                       MLIRContext *ctx, PatternBenefit benefit = 1)
      : OpConversionPattern<OpTy>(ctx, benefit), analysis(aliasAnalysis) {}

protected:
  /// Held by reference, so the analysis must outlive the pattern.
  const ResourceAliasAnalysis &analysis;
};
/// Erases a non-canonical aliased global variable (legality is decided by the
/// analysis, so this only fires on ops it marked for unification).
struct ConvertVariable : public ConvertAliasResource<spirv::GlobalVariableOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::GlobalVariableOp globalVar, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.eraseOp(globalVar);
    return success();
  }
};
/// Replaces a spirv.mlir.addressof of a non-canonical variable with one that
/// takes the address of the group's canonical variable.
struct ConvertAddressOf : public ConvertAliasResource<spirv::AddressOfOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::AddressOfOp addressOp, OpAdaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    Operation *symbol = SymbolTable::lookupSymbolIn(
        addressOp->getParentOfType<spirv::ModuleOp>(),
        addressOp.getVariable());
    spirv::GlobalVariableOp canonicalVar =
        analysis.getCanonicalResource(cast<spirv::GlobalVariableOp>(symbol));
    rewriter.replaceOpWithNewOp<spirv::AddressOfOp>(addressOp, canonicalVar);
    return success();
  }
};
/// Rewrites a spirv.AccessChain rooted at a non-canonical variable so that it
/// indexes into the canonical variable instead:
///  - equal element types or same-bitwidth scalars: indices are kept;
///  - scalar source into vector canonical: the last index is split into
///    (index / ratio, index % ratio);
///  - wider scalar/scalar or vector/vector source: the last index is scaled
///    by the size ratio.
struct ConvertAccessChain : public ConvertAliasResource<spirv::AccessChainOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::AccessChainOp acOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto addressOp = acOp.getBasePtr().getDefiningOp<spirv::AddressOfOp>();
    if (!addressOp)
      return rewriter.notifyMatchFailure(acOp, "base ptr not addressof op");

    auto parentModule = acOp->getParentOfType<spirv::ModuleOp>();
    auto srcVar = cast<spirv::GlobalVariableOp>(
        SymbolTable::lookupSymbolIn(parentModule, addressOp.getVariable()));
    spirv::GlobalVariableOp dstVar = analysis.getCanonicalResource(srcVar);
    spirv::SPIRVType srcElemType = analysis.getElementType(srcVar);
    spirv::SPIRVType dstElemType = analysis.getElementType(dstVar);

    // Same element size: the original indices already address the same data.
    bool keepIndices = srcElemType == dstElemType ||
                       areSameBitwidthScalarType(srcElemType, dstElemType);
    if (keepIndices) {
      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.getBasePtr(), adaptor.getIndices());
      return success();
    }

    bool srcIsScalar = srcElemType.isIntOrFloat();
    bool dstIsScalar = dstElemType.isIntOrFloat();
    bool srcIsVector = isa<VectorType>(srcElemType);
    bool dstIsVector = isa<VectorType>(dstElemType);
    Location loc = acOp.getLoc();

    if (srcIsScalar && dstIsVector) {
      // Address the containing vector, then the component inside it.
      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      assert(dstNumBytes >= srcNumBytes && dstNumBytes % srcNumBytes == 0);
      int ratio = dstNumBytes / srcNumBytes;

      auto indices = llvm::to_vector<4>(acOp.getIndices());
      Value lastIndex = indices.back();
      Type indexType = lastIndex.getType();
      Value ratioValue = rewriter.create<spirv::ConstantOp>(
          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
      Value vectorIndex =
          rewriter.create<spirv::SDivOp>(loc, indexType, lastIndex, ratioValue);
      Value componentIndex =
          rewriter.create<spirv::SModOp>(loc, indexType, lastIndex, ratioValue);
      indices.back() = vectorIndex;
      indices.push_back(componentIndex);

      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.getBasePtr(), indices);
      return success();
    }

    if ((srcIsScalar && dstIsScalar) || (srcIsVector && dstIsVector)) {
      // Each source element spans `ratio` canonical elements.
      int srcNumBytes = *srcElemType.getSizeInBytes();
      int dstNumBytes = *dstElemType.getSizeInBytes();
      assert(srcNumBytes >= dstNumBytes && srcNumBytes % dstNumBytes == 0);
      int ratio = srcNumBytes / dstNumBytes;

      auto indices = llvm::to_vector<4>(acOp.getIndices());
      Value lastIndex = indices.back();
      Type indexType = lastIndex.getType();
      Value ratioValue = rewriter.create<spirv::ConstantOp>(
          loc, indexType, rewriter.getIntegerAttr(indexType, ratio));
      indices.back() =
          rewriter.create<spirv::IMulOp>(loc, indexType, lastIndex, ratioValue);

      rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
          acOp, adaptor.getBasePtr(), indices);
      return success();
    }

    return rewriter.notifyMatchFailure(
        acOp, "unsupported src/dst types for spirv.AccessChain");
  }
};
/// Rewrites a spirv.Load through a pointer into a non-canonical resource so
/// that it reads from the canonical resource, yet still produces a value of
/// the originally loaded type:
///  - identical element types: a plain load from the converted pointer;
///  - same-bitwidth scalars: load, then spirv.Bitcast to the original type;
///  - wider scalar/scalar or vector/vector source: load `ratio` (at most 4)
///    consecutive canonical elements through the converted spirv.AccessChain,
///    combine them with spirv.CompositeConstruct and bitcast as needed.
///
/// All match-failure checks run before any IR is created, so a failed match
/// leaves the IR untouched.
struct ConvertLoad : public ConvertAliasResource<spirv::LoadOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::LoadOp loadOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto srcPtrType = cast<spirv::PointerType>(loadOp.getPtr().getType());
    auto srcElemType = cast<spirv::SPIRVType>(srcPtrType.getPointeeType());
    auto dstPtrType = cast<spirv::PointerType>(adaptor.getPtr().getType());
    auto dstElemType = cast<spirv::SPIRVType>(dstPtrType.getPointeeType());

    Location loc = loadOp.getLoc();

    if (srcElemType == dstElemType) {
      auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
      rewriter.replaceOp(loadOp, newLoadOp->getResults());
      return success();
    }

    if (areSameBitwidthScalarType(srcElemType, dstElemType)) {
      auto newLoadOp = rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
      auto castOp = rewriter.create<spirv::BitcastOp>(loc, srcElemType,
                                                      newLoadOp.getValue());
      rewriter.replaceOp(loadOp, castOp->getResults());
      return success();
    }

    bool bothScalars =
        srcElemType.isIntOrFloat() && dstElemType.isIntOrFloat();
    bool bothVectors =
        isa<VectorType>(srcElemType) && isa<VectorType>(dstElemType);
    if (!bothScalars && !bothVectors)
      return rewriter.notifyMatchFailure(
          loadOp, "unsupported src/dst types for spirv.Load");

    // The source element is wider than the canonical one; it is assembled
    // from `ratio` consecutive canonical elements.
    int srcNumBytes = *srcElemType.getSizeInBytes();
    int dstNumBytes = *dstElemType.getSizeInBytes();
    assert(srcNumBytes > dstNumBytes && srcNumBytes % dstNumBytes == 0);
    int ratio = srcNumBytes / dstNumBytes;
    if (ratio > 4)
      return rewriter.notifyMatchFailure(loadOp, "more than 4 components");

    auto acOp = adaptor.getPtr().getDefiningOp<spirv::AccessChainOp>();
    if (!acOp)
      return rewriter.notifyMatchFailure(loadOp, "ptr not spirv.AccessChain");

    // Every failure path is behind us; only now start creating IR.
    SmallVector<Value> components;
    components.reserve(ratio);
    Value firstComponent =
        rewriter.create<spirv::LoadOp>(loc, adaptor.getPtr());
    components.push_back(firstComponent);

    auto indices = llvm::to_vector<4>(acOp.getIndices());
    // Step the last index in its own integer type rather than assuming i32,
    // so that the IAdd operands and result type always agree.
    Type indexType = indices.back().getType();
    Value oneValue = spirv::ConstantOp::getOne(indexType, loc, rewriter);
    for (int i = 1; i < ratio; ++i) {
      // Load each subsequent canonical element belonging to this source
      // element.
      indices.back() = rewriter.create<spirv::IAddOp>(
          loc, indexType, indices.back(), oneValue);
      auto componentAcOp = rewriter.create<spirv::AccessChainOp>(
          loc, acOp.getBasePtr(), indices);
      components.push_back(
          rewriter.create<spirv::LoadOp>(loc, componentAcOp));
    }

    // Scalar sources are rebuilt as a `ratio`-wide vector of the canonical
    // type and bitcast back below; vector sources are rebuilt directly.
    Type vectorType = srcElemType;
    if (!isa<VectorType>(srcElemType))
      vectorType = VectorType::get({ratio}, dstElemType);

    // Vector/vector with differing element types: bitcast each loaded piece
    // to the source element type first so the composite construct type-checks.
    if (auto srcElemVecType = dyn_cast<VectorType>(srcElemType)) {
      if (auto dstElemVecType = dyn_cast<VectorType>(dstElemType)) {
        if (srcElemVecType.getElementType() !=
            dstElemVecType.getElementType()) {
          int64_t count =
              dstNumBytes / (srcElemVecType.getElementTypeBitWidth() / 8);
          // Avoid building single-element vector types.
          Type castType = srcElemVecType.getElementType();
          if (count > 1)
            castType = VectorType::get({count}, castType);
          for (Value &c : components)
            c = rewriter.create<spirv::BitcastOp>(loc, castType, c);
        }
      }
    }

    Value vectorValue = rewriter.create<spirv::CompositeConstructOp>(
        loc, vectorType, components);
    if (!isa<VectorType>(srcElemType))
      vectorValue =
          rewriter.create<spirv::BitcastOp>(loc, srcElemType, vectorValue);
    rewriter.replaceOp(loadOp, vectorValue);
    return success();
  }
};
/// Rewrites a spirv.Store through a pointer into a non-canonical resource so
/// it writes through the converted pointer. Only same-bitwidth scalar element
/// types are supported; the stored value is bitcast when the types differ,
/// and the original op's attributes are carried over.
struct ConvertStore : public ConvertAliasResource<spirv::StoreOp> {
  using ConvertAliasResource::ConvertAliasResource;

  LogicalResult
  matchAndRewrite(spirv::StoreOp storeOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value newPtr = adaptor.getPtr();
    Type originalElemType =
        cast<spirv::PointerType>(storeOp.getPtr().getType()).getPointeeType();
    Type canonicalElemType =
        cast<spirv::PointerType>(newPtr.getType()).getPointeeType();

    bool bothScalars =
        originalElemType.isIntOrFloat() && canonicalElemType.isIntOrFloat();
    if (!bothScalars)
      return rewriter.notifyMatchFailure(storeOp, "not scalar type");
    if (!areSameBitwidthScalarType(originalElemType, canonicalElemType))
      return rewriter.notifyMatchFailure(storeOp, "different bitwidth");

    Value storedValue = adaptor.getValue();
    if (originalElemType != canonicalElemType)
      storedValue = rewriter.create<spirv::BitcastOp>(
          storeOp.getLoc(), canonicalElemType, storedValue);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, newPtr, storedValue,
                                                storeOp->getAttrs());
    return success();
  }
};
namespace {
/// Pass that redirects all accesses to `aliased` SPIR-V resources sharing a
/// descriptor onto one canonical resource per descriptor.
class UnifyAliasedResourcePass final
    : public spirv::impl::SPIRVUnifyAliasedResourcePassBase<
          UnifyAliasedResourcePass> {
public:
  /// `getTargetEnv` may be empty; then the pass runs unconditionally.
  explicit UnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv)
      : getTargetEnvFn(std::move(getTargetEnv)) {}

  void runOnOperation() override;

private:
  /// Optional callback used to decide whether the target needs this pass.
  spirv::GetTargetEnvFn getTargetEnvFn;
};

void UnifyAliasedResourcePass::runOnOperation() {
  spirv::ModuleOp moduleOp = getOperation();

  // With a target-env callback, only WebGPU, Metal, and Vulkan on Apple
  // devices are processed; any other target is left untouched.
  if (getTargetEnvFn) {
    spirv::TargetEnvAttr targetEnv = getTargetEnvFn(moduleOp);
    spirv::ClientAPI api = targetEnv.getClientAPI();
    bool needsUnification =
        api == spirv::ClientAPI::WebGPU || api == spirv::ClientAPI::Metal ||
        (api == spirv::ClientAPI::Vulkan &&
         targetEnv.getVendorID() == spirv::Vendor::Apple);
    if (!needsUnification)
      return;
  }

  MLIRContext *context = &getContext();
  ResourceAliasAnalysis &analysis = getAnalysis<ResourceAliasAnalysis>();

  // Ops that touch a non-canonical resource are illegal and must be rewritten.
  ConversionTarget target(*context);
  target.addDynamicallyLegalOp<spirv::GlobalVariableOp, spirv::AddressOfOp,
                               spirv::AccessChainOp, spirv::LoadOp,
                               spirv::StoreOp>(
      [&analysis](Operation *op) { return !analysis.shouldUnify(op); });
  target.addLegalDialect<spirv::SPIRVDialect>();

  RewritePatternSet patterns(context);
  patterns.add<ConvertVariable, ConvertAddressOf, ConvertAccessChain,
               ConvertLoad, ConvertStore>(analysis, context);
  if (failed(applyPartialConversion(moduleOp, target, std::move(patterns)))) {
    signalPassFailure();
    return;
  }

  // Re-collect from the rewritten IR (the analysis map predates it) and drop
  // `aliased` wherever only a single resource remains bound to a descriptor.
  for (const auto &entry : collectAliasedResources(moduleOp)) {
    const auto &bound = entry.second;
    if (bound.size() == 1)
      bound.front()->removeAttr("aliased");
  }
}
} // namespace
/// Creates the unify-aliased-resource pass. If `getTargetEnv` is empty, the
/// pass runs regardless of target environment.
std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>>
spirv::createUnifyAliasedResourcePass(spirv::GetTargetEnvFn getTargetEnv) {
  std::unique_ptr<mlir::OperationPass<spirv::ModuleOp>> pass =
      std::make_unique<UnifyAliasedResourcePass>(std::move(getTargetEnv));
  return pass;
}