#include "Serializer.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/IR/RegionGraphTraits.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/DepthFirstIterator.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "spirv-serialization"
using namespace mlir;
/// Invokes `blockHandler` on each block reachable from `headerBlock`, in
/// depth-first pre-order.
///
/// Every block in `skipBlocks` is inserted into the visited set before the
/// walk begins, so the traversal neither hands those blocks to the handler nor
/// descends through them. When `skipHeader` is true, `headerBlock` still seeds
/// the traversal (its successors are reached) but is not passed to
/// `blockHandler`. The first failure returned by the handler aborts the walk.
static LogicalResult
visitInPrettyBlockOrder(Block *headerBlock,
                        function_ref<LogicalResult(Block *)> blockHandler,
                        bool skipHeader = false, BlockRange skipBlocks = {}) {
  llvm::df_iterator_default_set<Block *, 4> visited;
  visited.insert(skipBlocks.begin(), skipBlocks.end());

  for (Block *current : llvm::depth_first_ext(headerBlock, visited)) {
    const bool suppressHandler = skipHeader && current == headerBlock;
    if (!suppressHandler && failed(blockHandler(current)))
      return failure();
  }
  return success();
}
namespace mlir {
namespace spirv {
/// Serializes a `spirv.Constant`.
///
/// The constant is materialized through `prepareConstant`. A zero <id> from
/// that call is treated as failure. Otherwise the <id> is recorded as the
/// value of the op's result.
LogicalResult Serializer::processConstantOp(spirv::ConstantOp op) {
  auto constantID = prepareConstant(op.getLoc(), op.getType(), op.getValue());
  if (!constantID)
    return failure();

  valueIDMap[op.getResult()] = constantID;
  return success();
}
/// Serializes a scalar `spirv.SpecConstant`.
///
/// The default value is prepared as a specialization constant. An optional
/// `spec_id` integer attribute is emitted as a SpecId decoration. The
/// resulting <id> is registered under the op's symbol name and named via
/// `processName`.
LogicalResult Serializer::processSpecConstantOp(spirv::SpecConstantOp op) {
  auto specConstID = prepareConstantScalar(op.getLoc(), op.getDefaultValue(),
                                           /*isSpec=*/true);
  if (!specConstID)
    return failure();

  if (auto specIDAttr = op->getAttrOfType<IntegerAttr>("spec_id")) {
    // Truncated to the 32-bit literal width used by SPIR-V operands.
    auto specIDValue = static_cast<uint32_t>(specIDAttr.getInt());
    if (failed(emitDecoration(specConstID, spirv::Decoration::SpecId,
                              {specIDValue})))
      return failure();
  }

  specConstIDMap[op.getSymName()] = specConstID;
  return processName(specConstID, op.getSymName());
}
/// Serializes a `spirv.SpecConstantComposite` as an OpSpecConstantComposite in
/// the types/global-values section.
///
/// Each constituent must name a specialization constant whose <id> is already
/// known; otherwise an error naming the constituent is emitted. The composite's
/// <id> is registered under its symbol name and named via `processName`.
LogicalResult
Serializer::processSpecConstantCompositeOp(spirv::SpecConstantCompositeOp op) {
  uint32_t typeID = 0;
  if (failed(processType(op.getLoc(), op.getType(), typeID)))
    return failure();

  auto resultID = getNextID();
  SmallVector<uint32_t, 8> operands;
  operands.push_back(typeID);
  operands.push_back(resultID);

  for (Attribute constituent : op.getConstituents()) {
    // Every constituent is consumed as a flat symbol reference (presumably
    // guaranteed by the op verifier). cast<> makes that invariant explicit and
    // asserts on violation, instead of dereferencing a null dyn_cast result.
    auto constituentName = cast<FlatSymbolRefAttr>(constituent).getValue();
    auto constituentID = getSpecConstID(constituentName);
    if (!constituentID) {
      return op.emitError("unknown result <id> for specialization constant ")
             << constituentName;
    }
    operands.push_back(constituentID);
  }

  encodeInstructionInto(typesGlobalValues,
                        spirv::Opcode::OpSpecConstantComposite, operands);
  specConstIDMap[op.getSymName()] = resultID;
  return processName(resultID, op.getSymName());
}
/// Serializes a `spirv.SpecConstantOperation` as an OpSpecConstantOp.
///
/// The first operation of the region's first block is the wrapped op. Its
/// SPIR-V opcode is resolved by prefixing its dialect-stripped name with
/// "Op". All of its operands must already have <id>s.
LogicalResult
Serializer::processSpecConstantOperationOp(spirv::SpecConstantOperationOp op) {
  uint32_t resultTypeID = 0;
  if (failed(processType(op.getLoc(), op.getType(), resultTypeID)))
    return failure();

  // Allocated before opcode lookup, so a failed lookup still consumes an <id>.
  auto resultID = getNextID();

  Operation &wrappedOp = op.getRegion().front().front();

  std::string opcodeName;
  llvm::raw_string_ostream opcodeStream(opcodeName);
  opcodeStream << "Op" << wrappedOp.getName().stripDialect();
  auto wrappedOpcode = spirv::symbolizeOpcode(opcodeStream.str());
  if (!wrappedOpcode) {
    op.emitError("Couldn't find op code for op ")
        << wrappedOp.getName().getStringRef();
    return failure();
  }

  SmallVector<uint32_t, 8> operands = {
      resultTypeID, resultID, static_cast<uint32_t>(*wrappedOpcode)};
  for (Value wrappedOperand : wrappedOp.getOperands()) {
    uint32_t operandID = getValueID(wrappedOperand);
    assert(operandID && "use before def!");
    operands.push_back(operandID);
  }

  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpSpecConstantOp,
                        operands);
  valueIDMap[op.getResult()] = resultID;
  return success();
}
/// Serializes a `spirv.Undef`.
///
/// At most one OpUndef is emitted per result type. Its <id> is cached in
/// `undefValIDMap`, and every undef of that type maps to the cached <id>.
LogicalResult Serializer::processUndefOp(spirv::UndefOp op) {
  auto resultType = op.getType();
  // A zero cache entry means no OpUndef has been emitted for this type yet.
  auto &cachedID = undefValIDMap[resultType];
  if (!cachedID) {
    cachedID = getNextID();
    uint32_t resultTypeID = 0;
    if (failed(processType(op.getLoc(), resultType, resultTypeID)))
      return failure();
    encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpUndef,
                          {resultTypeID, cachedID});
  }

  valueIDMap[op.getResult()] = cachedID;
  return success();
}
/// Emits one OpFunctionParameter into `functionHeader` for each argument of
/// `op` and records each argument's <id> in `valueIDMap`.
///
/// If an argument's attribute dictionary holds an entry named
/// `DecorationAttr::name` whose value is a `DecorationAttr`, that decoration
/// is serialized onto the argument's <id>.
LogicalResult Serializer::processFuncParameter(spirv::FuncOp op) {
  // Loop-invariant: view the op through FunctionOpInterface once, not once per
  // argument.
  auto funcOp = cast<FunctionOpInterface>(*op);

  for (auto [idx, arg] : llvm::enumerate(op.getArguments())) {
    uint32_t argTypeID = 0;
    if (failed(processType(op.getLoc(), arg.getType(), argTypeID)))
      return failure();

    auto argValueID = getNextID();

    for (auto argAttr : funcOp.getArgAttrs(idx)) {
      if (argAttr.getName() != DecorationAttr::name)
        continue;
      if (auto decAttr = dyn_cast<DecorationAttr>(argAttr.getValue())) {
        if (failed(processDecorationAttr(op->getLoc(), argValueID,
                                         decAttr.getValue(), decAttr)))
          return failure();
      }
    }

    valueIDMap[arg] = argValueID;
    encodeInstructionInto(functionHeader, spirv::Opcode::OpFunctionParameter,
                          {argTypeID, argValueID});
  }
  return success();
}
/// Serializes a `spirv.func` and appends it to the `functions` section.
///
/// Emits OpFunction and the function's name, then either:
///  - for an external function with Import linkage: only the
///    OpFunctionParameter instructions (a declaration), or
///  - for a defined function: the parameters, the entry block (labelled in the
///    header) and all remaining blocks in depth-first order, followed by
///    patching of deferred OpPhi operand references.
/// Attributes whose names map to a SPIR-V Decoration are emitted as
/// decorations on the function <id>. An external function without Import
/// linkage, or a function with more than one result, is rejected.
LogicalResult Serializer::processFuncOp(spirv::FuncOp op) {
  LLVM_DEBUG(llvm::dbgs() << "-- start function '" << op.getName() << "' --\n");
  assert(functionHeader.empty() && functionBody.empty());

  uint32_t fnTypeID = 0;
  if (failed(processType(op.getLoc(), op.getFunctionType(), fnTypeID)))
    return failure();

  SmallVector<uint32_t, 4> operands;
  uint32_t resTypeID = 0;
  auto resultTypes = op.getFunctionType().getResults();
  if (resultTypes.size() > 1) {
    return op.emitError("cannot serialize function with multiple return types");
  }
  if (failed(processType(op.getLoc(),
                         (resultTypes.empty() ? getVoidType() : resultTypes[0]),
                         resTypeID))) {
    return failure();
  }
  operands.push_back(resTypeID);
  auto funcID = getOrCreateFunctionID(op.getName());
  operands.push_back(funcID);
  operands.push_back(static_cast<uint32_t>(op.getFunctionControl()));
  operands.push_back(fnTypeID);
  encodeInstructionInto(functionHeader, spirv::Opcode::OpFunction, operands);

  if (failed(processName(funcID, op.getName())))
    return failure();

  auto linkageAttr = op.getLinkageAttributes();
  bool hasImportLinkage =
      linkageAttr && (linkageAttr.value().getLinkageType().getValue() ==
                      spirv::LinkageType::Import);

  if (op.isExternal()) {
    if (!hasImportLinkage) {
      return op.emitError(
          "'spirv.module' cannot contain external functions "
          "without 'Import' linkage_attributes (LinkageAttributes)");
    }
    // processFuncParameter walks op.getArguments(), i.e. entry-block
    // arguments, which a bodiless declaration lacks. A temporary entry block
    // supplies them. It is removed on every path, including failure, so the
    // input IR is left as it was found.
    op.addEntryBlock();
    LogicalResult paramResult = processFuncParameter(op);
    // The block arguments are about to be destroyed. Drop their map entries so
    // no stale Value key can alias a later-allocated value.
    for (BlockArgument arg : op.getArguments())
      valueIDMap.erase(arg);
    op.eraseBody();
    if (failed(paramResult))
      return failure();
  } else {
    if (failed(processFuncParameter(op)))
      return failure();

    // The entry block's OpLabel goes into the header. The block is then
    // processed with its label omitted.
    encodeInstructionInto(functionHeader, spirv::Opcode::OpLabel,
                          {getOrCreateBlockID(&op.front())});
    if (failed(processBlock(&op.front(), /*omitLabel=*/true)))
      return failure();
    if (failed(visitInPrettyBlockOrder(
            &op.front(), [&](Block *block) { return processBlock(block); },
            /*skipHeader=*/true))) {
      return failure();
    }

    // Each deferred entry lists offsets in functionBody where a value's <id>
    // must be written once known (OpPhi operands, per the assert below).
    for (const auto &deferredValue : deferredPhiValues) {
      Value value = deferredValue.first;
      uint32_t id = getValueID(value);
      LLVM_DEBUG(llvm::dbgs() << "[phi] fix reference of value " << value
                              << " to id = " << id << '\n');
      assert(id && "OpPhi references undefined value!");
      for (size_t offset : deferredValue.second)
        functionBody[offset] = id;
    }
    deferredPhiValues.clear();
  }
  LLVM_DEBUG(llvm::dbgs() << "-- completed function '" << op.getName()
                          << "' --\n");

  // Attributes whose snake_case names correspond to a SPIR-V Decoration are
  // emitted as decorations on the function.
  for (auto attr : op->getAttrs()) {
    auto isValidDecoration = mlir::spirv::symbolizeEnum<spirv::Decoration>(
        llvm::convertToCamelFromSnakeCase(attr.getName().strref(),
                                          /*capitalizeFirst=*/true));
    if (isValidDecoration != std::nullopt) {
      if (failed(processDecoration(op.getLoc(), funcID, attr))) {
        return failure();
      }
    }
  }

  encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionEnd, {});
  functions.append(functionHeader.begin(), functionHeader.end());
  functions.append(functionBody.begin(), functionBody.end());
  functionHeader.clear();
  functionBody.clear();
  return success();
}
/// Serializes a `spirv.Variable` as an OpVariable in the current function's
/// header stream.
///
/// Operands are the result type, result <id>, the storage class (if the
/// attribute is present) and any operand-0 values (which must already have
/// <id>s). Every attribute other than the storage class is emitted as a
/// decoration on the variable's <id>.
LogicalResult Serializer::processVariableOp(spirv::VariableOp op) {
  uint32_t resultTypeID = 0;
  if (failed(processType(op.getLoc(), op.getType(), resultTypeID)))
    return failure();

  uint32_t resultID = getNextID();
  valueIDMap[op.getResult()] = resultID;
  SmallVector<uint32_t, 4> operands = {resultTypeID, resultID};

  StringRef storageClassName = spirv::attributeName<spirv::StorageClass>();
  if (auto storageAttr = op->getAttr(storageClassName))
    operands.push_back(static_cast<uint32_t>(
        cast<spirv::StorageClassAttr>(storageAttr).getValue()));

  for (auto initValue : op.getODSOperands(0)) {
    auto initID = getValueID(initValue);
    if (!initID)
      return emitError(op.getLoc(), "operand 0 has a use before def");
    operands.push_back(initID);
  }

  if (failed(emitDebugLine(functionHeader, op.getLoc())))
    return failure();
  encodeInstructionInto(functionHeader, spirv::Opcode::OpVariable, operands);

  // The storage class was encoded as an operand above, so it is not a
  // decoration.
  for (auto namedAttr : op->getAttrs()) {
    if (namedAttr.getName() == storageClassName)
      continue;
    if (failed(processDecoration(op.getLoc(), resultID, namedAttr)))
      return failure();
  }
  return success();
}
/// Serializes a `spirv.GlobalVariable` as an OpVariable in the
/// types/global-values section and records its <id> in `globalVarIDMap`.
///
/// If an initializer symbol is present, it is resolved to either another
/// global variable or a specialization constant, whose <id> must already be
/// known. All attributes except `type`, the symbol name and the initializer
/// are emitted as decorations on the variable's <id>.
LogicalResult
Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
  uint32_t resultTypeID = 0;
  SmallVector<StringRef, 4> elidedAttrs;
  if (failed(processType(varOp.getLoc(), varOp.getType(), resultTypeID)))
    return failure();
  elidedAttrs.push_back("type");

  SmallVector<uint32_t, 4> operands;
  operands.push_back(resultTypeID);
  auto resultID = getNextID();

  auto varName = varOp.getSymName();
  elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
  if (failed(processName(resultID, varName)))
    return failure();
  globalVarIDMap[varName] = resultID;
  operands.push_back(resultID);

  operands.push_back(static_cast<uint32_t>(varOp.storageClass()));

  StringRef initAttrName = varOp.getInitializerAttrName().getValue();
  // The initializer is encoded as an OpVariable operand, never as a decoration,
  // so it is always excluded from the decoration loop below.
  elidedAttrs.push_back(initAttrName);
  if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
    auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
    Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
        varOp->getParentOp(), initRef.getAttr());

    uint32_t initializerID = 0;
    // An unresolvable symbol yields a null op. Guard it so isa<> never sees
    // null, and fall through to the "undefined" error below.
    if (initOp) {
      initializerID = isa<spirv::GlobalVariableOp>(initOp)
                          ? getVariableID(*initSymbolName)
                          : getSpecConstID(*initSymbolName);
    }
    if (!initializerID)
      return emitError(varOp.getLoc(),
                       "invalid usage of undefined variable as initializer");
    operands.push_back(initializerID);
  }

  if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
    return failure();
  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);

  for (auto attr : varOp->getAttrs()) {
    if (llvm::any_of(elidedAttrs, [&](StringRef elided) {
          return attr.getName() == elided;
        })) {
      continue;
    }
    if (failed(processDecoration(varOp.getLoc(), resultID, attr)))
      return failure();
  }
  return success();
}
/// Serializes a `spirv.mlir.selection` as structured SPIR-V control flow.
///
/// Steps:
///  1. Give every block of the region an <id> up front, so branches between
///     them can be encoded in any visitation order.
///  2. Close the current block with an OpBranch to the selection header.
///  3. Process the header, handing processBlock a callback that emits
///     OpSelectionMerge (presumably placed right before the header's
///     terminator).
///  4. Visit the remaining blocks depth-first, excluding the merge block.
///  5. Emit an OpLabel carrying the merge block's <id>, so subsequent
///     instructions land in that block.
LogicalResult Serializer::processSelectionOp(spirv::SelectionOp selectionOp) {
  for (Block &regionBlock : selectionOp.getBody())
    getOrCreateBlockID(&regionBlock);

  Block *header = selectionOp.getHeaderBlock();
  Block *merge = selectionOp.getMergeBlock();
  uint32_t headerLabel = getBlockID(header);
  uint32_t mergeLabel = getBlockID(merge);
  Location selectionLoc = selectionOp.getLoc();

  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerLabel});

  auto emitMergeInst = [&]() -> LogicalResult {
    if (failed(emitDebugLine(functionBody, selectionLoc)))
      return failure();
    lastProcessedWasMergeInst = true;
    auto control = static_cast<uint32_t>(selectionOp.getSelectionControl());
    encodeInstructionInto(functionBody, spirv::Opcode::OpSelectionMerge,
                          {mergeLabel, control});
    return success();
  };
  if (failed(processBlock(header, /*omitLabel=*/false, emitMergeInst)))
    return failure();

  auto handleBlock = [&](Block *block) { return processBlock(block); };
  if (failed(visitInPrettyBlockOrder(header, handleBlock, /*skipHeader=*/true,
                                     /*skipBlocks=*/{merge})))
    return failure();

  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeLabel});
  LLVM_DEBUG(llvm::dbgs() << "done merge ");
  LLVM_DEBUG(printBlock(merge, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  return success();
}
/// Serializes a `spirv.mlir.loop` as structured SPIR-V control flow.
///
/// The region's first block is not assigned an <id>; presumably it only
/// exists to satisfy MLIR region structure. All other blocks get <id>s up
/// front. Steps:
///  1. Close the current block with an OpBranch to the loop header.
///  2. Process the header, handing processBlock a callback that emits
///     OpLoopMerge.
///  3. Visit the body blocks depth-first, excluding the continue and merge
///     blocks.
///  4. Process the continue block last.
///  5. Emit an OpLabel with the merge block's <id> so subsequent instructions
///     land in it.
LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
  for (Block &regionBlock : llvm::drop_begin(loopOp.getBody()))
    getOrCreateBlockID(&regionBlock);

  Block *header = loopOp.getHeaderBlock();
  Block *continueTarget = loopOp.getContinueBlock();
  Block *merge = loopOp.getMergeBlock();
  uint32_t headerLabel = getBlockID(header);
  uint32_t continueLabel = getBlockID(continueTarget);
  uint32_t mergeLabel = getBlockID(merge);
  Location loopLoc = loopOp.getLoc();

  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {headerLabel});

  auto emitMergeInst = [&]() -> LogicalResult {
    if (failed(emitDebugLine(functionBody, loopLoc)))
      return failure();
    lastProcessedWasMergeInst = true;
    auto control = static_cast<uint32_t>(loopOp.getLoopControl());
    encodeInstructionInto(functionBody, spirv::Opcode::OpLoopMerge,
                          {mergeLabel, continueLabel, control});
    return success();
  };
  if (failed(processBlock(header, /*omitLabel=*/false, emitMergeInst)))
    return failure();

  auto handleBlock = [&](Block *block) { return processBlock(block); };
  if (failed(visitInPrettyBlockOrder(header, handleBlock, /*skipHeader=*/true,
                                     /*skipBlocks=*/{continueTarget, merge})))
    return failure();

  if (failed(processBlock(continueTarget)))
    return failure();

  encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {mergeLabel});
  LLVM_DEBUG(llvm::dbgs() << "done merge ");
  LLVM_DEBUG(printBlock(merge, llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << "\n");
  return success();
}
/// Serializes a `spirv.BranchConditional` as OpBranchConditional.
///
/// Operands are the condition <id>, the true- and false-target labels
/// (created on demand, true target first) and, when present, the branch
/// weights truncated to 32-bit literals.
LogicalResult Serializer::processBranchConditionalOp(
    spirv::BranchConditionalOp condBranchOp) {
  uint32_t condID = getValueID(condBranchOp.getCondition());
  uint32_t trueTargetID = getOrCreateBlockID(condBranchOp.getTrueBlock());
  uint32_t falseTargetID = getOrCreateBlockID(condBranchOp.getFalseBlock());
  SmallVector<uint32_t, 5> operands = {condID, trueTargetID, falseTargetID};

  if (auto branchWeights = condBranchOp.getBranchWeights())
    for (Attribute weight : branchWeights->getValue())
      operands.push_back(
          static_cast<uint32_t>(cast<IntegerAttr>(weight).getInt()));

  if (failed(emitDebugLine(functionBody, condBranchOp.getLoc())))
    return failure();
  encodeInstructionInto(functionBody, spirv::Opcode::OpBranchConditional,
                        operands);
  return success();
}
/// Serializes a `spirv.Branch` as an unconditional OpBranch to the target
/// block's label, which is created on demand. A debug line is emitted first.
LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
  Location branchLoc = branchOp.getLoc();
  if (failed(emitDebugLine(functionBody, branchLoc)))
    return failure();

  uint32_t targetLabel = getOrCreateBlockID(branchOp.getTarget());
  encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {targetLabel});
  return success();
}
/// Serializes `spirv.mlir.addressof`.
///
/// No instruction is emitted. The op's pointer result is mapped to the <id>
/// of the referenced global variable, which must already be known.
LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
  auto globalName = addressOfOp.getVariable();
  uint32_t globalID = getVariableID(globalName);
  if (globalID == 0)
    return addressOfOp.emitError("unknown result <id> for variable ")
           << globalName;

  valueIDMap[addressOfOp.getPointer()] = globalID;
  return success();
}
/// Serializes `spirv.mlir.referenceof`.
///
/// No instruction is emitted. The op's result is mapped to the <id> of the
/// referenced specialization constant, which must already be known.
LogicalResult
Serializer::processReferenceOfOp(spirv::ReferenceOfOp referenceOfOp) {
  auto specConstName = referenceOfOp.getSpecConst();
  uint32_t specConstID = getSpecConstID(specConstName);
  if (specConstID == 0)
    return referenceOfOp.emitError(
               "unknown result <id> for specialization constant ")
           << specConstName;

  valueIDMap[referenceOfOp.getReference()] = specConstID;
  return success();
}
/// Serializes `spirv.EntryPoint` into the entry-points section as
/// OpEntryPoint: the execution model, the function <id>, the function name
/// as a string literal, then the <id> of each interface variable.
///
/// The function and every interface variable must already have been
/// serialized; otherwise an error is emitted.
template <>
LogicalResult
Serializer::processOp<spirv::EntryPointOp>(spirv::EntryPointOp op) {
  SmallVector<uint32_t, 4> operands;
  operands.push_back(static_cast<uint32_t>(op.getExecutionModel()));

  auto funcID = getFunctionID(op.getFn());
  if (!funcID) {
    return op.emitError("missing <id> for function ")
           << op.getFn()
           << "; function needs to be defined before spirv.EntryPoint is "
              "serialized";
  }
  operands.push_back(funcID);
  spirv::encodeStringLiteralInto(operands, op.getFn());

  if (auto interface = op.getInterface()) {
    for (auto var : interface.getValue()) {
      auto id = getVariableID(cast<FlatSymbolRefAttr>(var).getValue());
      if (!id) {
        // Adjacent literals are concatenated verbatim, so the sentence break
        // needs an explicit space.
        return op.emitError(
            "referencing undefined global variable. "
            "spirv.EntryPoint is at the end of spirv.module. All "
            "referenced variables should already be defined");
      }
      operands.push_back(id);
    }
  }

  encodeInstructionInto(entryPoints, spirv::Opcode::OpEntryPoint, operands);
  return success();
}
/// Serializes `spirv.ExecutionMode` as OpExecutionMode: the target function's
/// <id>, the execution mode, then each literal value zero-extended and
/// truncated to 32 bits. The target function must already be serialized.
template <>
LogicalResult
Serializer::processOp<spirv::ExecutionModeOp>(spirv::ExecutionModeOp op) {
  auto targetFuncID = getFunctionID(op.getFn());
  if (!targetFuncID)
    return op.emitError("missing <id> for function ")
           << op.getFn()
           << "; function needs to be serialized before ExecutionModeOp is "
              "serialized";

  SmallVector<uint32_t, 4> operands = {
      targetFuncID, static_cast<uint32_t>(op.getExecutionMode())};

  if (auto modeValues = op.getValues()) {
    for (Attribute literal : modeValues.getValue())
      operands.push_back(static_cast<uint32_t>(
          llvm::cast<IntegerAttr>(literal).getValue().getZExtValue()));
  }

  encodeInstructionInto(executionModes, spirv::Opcode::OpExecutionMode,
                        operands);
  return success();
}
/// Serializes `spirv.FunctionCall` as OpFunctionCall.
///
/// A call without results uses `getVoidType()` as its result type. The
/// callee's <id> is obtained via getOrCreateFunctionID, so the call may be
/// serialized before the callee. Every argument must already have an <id>.
template <>
LogicalResult
Serializer::processOp<spirv::FunctionCallOp>(spirv::FunctionCallOp op) {
  Type resultTy;
  if (op.getNumResults() != 0)
    resultTy = *op.result_type_begin();
  else
    resultTy = getVoidType();

  uint32_t resTypeID = 0;
  if (failed(processType(op.getLoc(), resultTy, resTypeID)))
    return failure();

  auto calleeID = getOrCreateFunctionID(op.getCallee());
  auto callID = getNextID();
  SmallVector<uint32_t, 8> operands{resTypeID, callID, calleeID};
  for (auto callArg : op.getArguments()) {
    auto callArgID = getValueID(callArg);
    assert(callArgID && "cannot find a value for spirv.FunctionCall");
    operands.push_back(callArgID);
  }

  // Only map the SSA result when the result type is not NoneType (presumably
  // what getVoidType() returns).
  if (!isa<NoneType>(resultTy))
    valueIDMap[op.getResult(0)] = callID;

  encodeInstructionInto(functionBody, spirv::Opcode::OpFunctionCall, operands);
  return success();
}
/// Serializes `spirv.CopyMemory` as OpCopyMemory.
///
/// Operands are emitted in this order:
///  1. the operand <id>s;
///  2. the target memory access mask and alignment, each only if present;
///  3. the source memory access mask and alignment, each only if present.
template <>
LogicalResult
Serializer::processOp<spirv::CopyMemoryOp>(spirv::CopyMemoryOp op) {
  SmallVector<uint32_t, 4> operands;
  for (Value operand : op->getOperands()) {
    auto id = getValueID(operand);
    assert(id && "use before def!");
    operands.push_back(id);
  }

  // Appends the MemoryAccess mask stored under `attrName`, if present.
  auto appendMemoryAccess = [&](StringAttr attrName) {
    if (auto attr = op->getAttr(attrName))
      operands.push_back(static_cast<uint32_t>(
          cast<spirv::MemoryAccessAttr>(attr).getValue()));
  };
  // Appends the integer alignment stored under `attrName`, if present.
  auto appendAlignment = [&](StringAttr attrName) {
    if (auto attr = op->getAttr(attrName))
      operands.push_back(static_cast<uint32_t>(
          cast<IntegerAttr>(attr).getValue().getZExtValue()));
  };

  appendMemoryAccess(op.getMemoryAccessAttrName());
  appendAlignment(op.getAlignmentAttrName());
  appendMemoryAccess(op.getSourceMemoryAccessAttrName());
  appendAlignment(op.getSourceAlignmentAttrName());

  if (failed(emitDebugLine(functionBody, op.getLoc())))
    return failure();
  encodeInstructionInto(functionBody, spirv::Opcode::OpCopyMemory, operands);
  return success();
}
/// Serializes `spirv.GenericCastToPtrExplicit` as OpGenericCastToPtrExplicit.
///
/// Operands are the result type, the result <id>, the <id>s of the op's
/// operands (which must already be defined), and the storage class of the
/// result pointer type.
template <>
LogicalResult Serializer::processOp<spirv::GenericCastToPtrExplicitOp>(
    spirv::GenericCastToPtrExplicitOp op) {
  Location loc = op->getLoc();
  Type resultTy = op->getResult(0).getType();

  uint32_t resultTypeID = 0;
  if (failed(processType(loc, resultTy, resultTypeID)))
    return failure();

  uint32_t resultID = getNextID();
  valueIDMap[op->getResult(0)] = resultID;

  SmallVector<uint32_t, 4> operands{resultTypeID, resultID};
  for (Value operand : op->getOperands()) {
    uint32_t operandID = getValueID(operand);
    // A zero <id> would silently encode an invalid reference. Assert like the
    // other op handlers do.
    assert(operandID && "use before def!");
    operands.push_back(operandID);
  }

  spirv::StorageClass resultStorage =
      cast<spirv::PointerType>(resultTy).getStorageClass();
  operands.push_back(static_cast<uint32_t>(resultStorage));

  encodeInstructionInto(functionBody, spirv::Opcode::OpGenericCastToPtrExplicit,
                        operands);
  return success();
}
#define GET_SERIALIZATION_FNS
#include "mlir/Dialect/SPIRV/IR/SPIRVSerialization.inc"
}
}