#include "Serializer.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Target/SPIRV/SPIRVBinaryUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/ADT/bit.h"
#include "llvm/Support/Debug.h"
#include <cstdint>
#include <optional>
#define DEBUG_TYPE "spirv-serialization"
using namespace mlir;
/// Returns the merge block of a structured control flow op
/// (spirv::SelectionOp or spirv::LoopOp), or nullptr for any other op.
static Block *getStructuredControlFlowOpMergeBlock(Operation *op) {
  return TypeSwitch<Operation *, Block *>(op)
      .Case<spirv::SelectionOp, spirv::LoopOp>(
          [](auto structuredOp) { return structuredOp.getMergeBlock(); })
      .Default([](Operation *) -> Block * { return nullptr; });
}
/// Returns the block to record as the OpPhi parent for values flowing out of
/// `block`. Structured control flow ops are serialized into several SPIR-V
/// blocks, so the block that actually branches onward may differ from `block`:
///  - For the entry block of a spirv::LoopOp, the result is the merge block of
///    the nearest structured op preceding the loop in its parent block, or the
///    loop's parent block when there is no such op.
///  - Otherwise, if `block` contains structured ops, the result is the merge
///    block of the last one.
///  - Otherwise `block` itself is returned.
static Block *getPhiIncomingBlock(Block *block) {
  if (block->isEntryBlock()) {
    if (auto loopOp = dyn_cast<spirv::LoopOp>(block->getParentOp())) {
      for (Operation *prev = loopOp->getPrevNode(); prev != nullptr;
           prev = prev->getPrevNode()) {
        if (Block *mergeBlock = getStructuredControlFlowOpMergeBlock(prev))
          return mergeBlock;
      }
      return loopOp->getBlock();
    }
  }

  for (Operation &nested : llvm::reverse(block->getOperations()))
    if (Block *mergeBlock = getStructuredControlFlowOpMergeBlock(&nested))
      return mergeBlock;

  return block;
}
namespace mlir {
namespace spirv {
/// Appends one SPIR-V instruction to `binary`: a leading word packing the
/// total word count (opcode word + operands) with `op`, followed by
/// `operands` verbatim.
void encodeInstructionInto(SmallVectorImpl<uint32_t> &binary, spirv::Opcode op,
                           ArrayRef<uint32_t> operands) {
  const uint32_t totalWords = static_cast<uint32_t>(operands.size()) + 1;
  binary.push_back(spirv::getPrefixedOpcode(totalWords, op));
  llvm::append_range(binary, operands);
}
Serializer::Serializer(spirv::ModuleOp module,
const SerializationOptions &options)
: module(module), mlirBuilder(module.getContext()), options(options) {}
/// Serializes the whole module into the per-section word buffers. The module
/// is verified first; then the module-level sections (capabilities,
/// extensions, memory model, debug info) are emitted, followed by every op in
/// the module body. Stops at the first op that fails to serialize.
LogicalResult Serializer::serialize() {
  LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");

  if (failed(module.verifyInvariants()))
    return failure();

  processCapability();
  processExtension();
  processMemoryModel();
  processDebugInfo();

  // any_of stops at the first failing op, just like an early return would.
  bool anyOpFailed = llvm::any_of(*module.getBody(), [&](Operation &op) {
    return failed(processOperation(&op));
  });
  if (anyOpFailed)
    return failure();

  LLVM_DEBUG(llvm::dbgs() << "+++ completed serialization +++\n");
  return success();
}
/// Concatenates the header and all serialized sections into `binary`, in the
/// logical module layout order. Any previous contents of `binary` are
/// discarded.
void Serializer::collect(SmallVectorImpl<uint32_t> &binary) {
  // Reserve exactly the number of words appended below. Every appended
  // section must be counted here -- including `debug` and `names`, which
  // were previously missing. Otherwise the reservation is too small and
  // `binary` may reallocate while it is being filled.
  auto moduleSize = spirv::kHeaderWordCount + capabilities.size() +
                    extensions.size() + extendedSets.size() +
                    memoryModel.size() + entryPoints.size() +
                    executionModes.size() + debug.size() + names.size() +
                    decorations.size() + typesGlobalValues.size() +
                    functions.size();

  binary.clear();
  binary.reserve(moduleSize);

  spirv::appendModuleHeader(binary, module.getVceTriple()->getVersion(),
                            nextID);
  binary.append(capabilities.begin(), capabilities.end());
  binary.append(extensions.begin(), extensions.end());
  binary.append(extendedSets.begin(), extendedSets.end());
  binary.append(memoryModel.begin(), memoryModel.end());
  binary.append(entryPoints.begin(), entryPoints.end());
  binary.append(executionModes.begin(), executionModes.end());
  binary.append(debug.begin(), debug.end());
  binary.append(names.begin(), names.end());
  binary.append(decorations.begin(), decorations.end());
  binary.append(typesGlobalValues.begin(), typesGlobalValues.end());
  binary.append(functions.begin(), functions.end());
}
#ifndef NDEBUG
/// Debug helper: prints every entry of `valueIDMap` to `os`, and says whether
/// each value is an op result or a block argument.
void Serializer::printValueIDMap(raw_ostream &os) {
  os << "\n= Value <id> Map =\n\n";
  for (auto &entry : valueIDMap) {
    Value val = entry.first;
    os << " " << val << " "
       << "id = " << entry.second << ' ';

    if (Operation *definingOp = val.getDefiningOp()) {
      os << "from op '" << definingOp->getName() << "'";
    } else if (auto blockArg = dyn_cast<BlockArgument>(val)) {
      Block *owner = blockArg.getOwner();
      os << "from argument of block " << owner << ' ';
      os << " in op '" << owner->getParentOp()->getName() << "'";
    }
    os << '\n';
  }
}
#endif
/// Returns the <id> assigned to function `fnName`, allocating (and recording)
/// a fresh one on first use.
uint32_t Serializer::getOrCreateFunctionID(StringRef fnName) {
  if (uint32_t existingID = funcIDMap.lookup(fnName))
    return existingID;

  uint32_t freshID = getNextID();
  funcIDMap[fnName] = freshID;
  return freshID;
}
void Serializer::processCapability() {
for (auto cap : module.getVceTriple()->getCapabilities())
encodeInstructionInto(capabilities, spirv::Opcode::OpCapability,
{static_cast<uint32_t>(cap)});
}
void Serializer::processDebugInfo() {
if (!options.emitDebugInfo)
return;
auto fileLoc = dyn_cast<FileLineColLoc>(module.getLoc());
auto fileName = fileLoc ? fileLoc.getFilename().strref() : "<unknown>";
fileID = getNextID();
SmallVector<uint32_t, 16> operands;
operands.push_back(fileID);
spirv::encodeStringLiteralInto(operands, fileName);
encodeInstructionInto(debug, spirv::Opcode::OpString, operands);
}
void Serializer::processExtension() {
llvm::SmallVector<uint32_t, 16> extName;
for (spirv::Extension ext : module.getVceTriple()->getExtensions()) {
extName.clear();
spirv::encodeStringLiteralInto(extName, spirv::stringifyExtension(ext));
encodeInstructionInto(extensions, spirv::Opcode::OpExtension, extName);
}
}
/// Emits OpMemoryModel from the module's addressing-model and memory-model
/// attributes (operand order: addressing model, then memory model).
void Serializer::processMemoryModel() {
  auto memoryModelAttr = module->getAttrOfType<spirv::MemoryModelAttr>(
      module.getMemoryModelAttrName());
  auto addressingModelAttr = module->getAttrOfType<spirv::AddressingModelAttr>(
      module.getAddressingModelAttrName());

  uint32_t memoryWord = static_cast<uint32_t>(memoryModelAttr.getValue());
  uint32_t addressingWord =
      static_cast<uint32_t>(addressingModelAttr.getValue());
  encodeInstructionInto(memoryModel, spirv::Opcode::OpMemoryModel,
                        {addressingWord, memoryWord});
}
/// Maps a snake_case attribute name to the CamelCase spelling of the
/// corresponding SPIR-V decoration. "fp_fast_math_mode" is special-cased
/// because the decoration is spelled "FPFastMathMode"; generic conversion
/// would only capitalize the first letter of "fp".
static std::string getDecorationName(StringRef attrName) {
  if (attrName != "fp_fast_math_mode")
    return llvm::convertToCamelFromSnakeCase(attrName, /*capitalizeFirst=*/true);
  return "FPFastMathMode";
}
/// Emits an OpDecorate of `decoration` on `resultID`, with extra operands
/// derived from `attr`.
///
/// Each supported decoration expects a specific attribute kind. A mismatched
/// attribute, an unknown BuiltIn name, or an unsupported decoration produces
/// an error at `loc` and returns failure.
LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
                                                Decoration decoration,
                                                Attribute attr) {
  SmallVector<uint32_t, 1> args;
  switch (decoration) {
  case spirv::Decoration::LinkageAttributes: {
    // Operands: the linkage name as a string literal, then the linkage type.
    auto linkageAttr = llvm::dyn_cast<spirv::LinkageAttributesAttr>(attr);
    // Guard the cast like every other case; previously a mismatched
    // attribute dereferenced a null LinkageAttributesAttr.
    if (!linkageAttr)
      return emitError(loc, "expected LinkageAttributesAttr attribute for ")
             << stringifyDecoration(decoration);
    auto linkageName = linkageAttr.getLinkageName();
    auto linkageType = linkageAttr.getLinkageType().getValue();
    spirv::encodeStringLiteralInto(args, linkageName);
    args.push_back(static_cast<uint32_t>(linkageType));
    break;
  }
  case spirv::Decoration::FPFastMathMode:
    if (auto intAttr = dyn_cast<FPFastMathModeAttr>(attr)) {
      args.push_back(static_cast<uint32_t>(intAttr.getValue()));
      break;
    }
    return emitError(loc, "expected FPFastMathModeAttr attribute for ")
           << stringifyDecoration(decoration);
  case spirv::Decoration::Binding:
  case spirv::Decoration::DescriptorSet:
  case spirv::Decoration::Location:
    if (auto intAttr = dyn_cast<IntegerAttr>(attr)) {
      args.push_back(intAttr.getValue().getZExtValue());
      break;
    }
    return emitError(loc, "expected integer attribute for ")
           << stringifyDecoration(decoration);
  case spirv::Decoration::BuiltIn:
    if (auto strAttr = dyn_cast<StringAttr>(attr)) {
      auto enumVal = spirv::symbolizeBuiltIn(strAttr.getValue());
      if (enumVal) {
        args.push_back(static_cast<uint32_t>(*enumVal));
        break;
      }
      return emitError(loc, "invalid ")
             << stringifyDecoration(decoration) << " decoration attribute "
             << strAttr.getValue();
    }
    return emitError(loc, "expected string attribute for ")
           << stringifyDecoration(decoration);
  case spirv::Decoration::Aliased:
  case spirv::Decoration::AliasedPointer:
  case spirv::Decoration::Flat:
  case spirv::Decoration::NonReadable:
  case spirv::Decoration::NonWritable:
  case spirv::Decoration::NoPerspective:
  case spirv::Decoration::NoSignedWrap:
  case spirv::Decoration::NoUnsignedWrap:
  case spirv::Decoration::RelaxedPrecision:
  case spirv::Decoration::Restrict:
  case spirv::Decoration::RestrictPointer:
  case spirv::Decoration::NoContraction:
    // These decorations carry no extra operands.
    if (isa<UnitAttr, DecorationAttr>(attr))
      break;
    return emitError(loc,
                     "expected unit attribute or decoration attribute for ")
           << stringifyDecoration(decoration);
  default:
    return emitError(loc, "unhandled decoration ")
           << stringifyDecoration(decoration);
  }
  return emitDecoration(resultID, decoration, args);
}
/// Serializes a (non-argument) op attribute as a decoration on `resultID`.
/// The snake_case attribute name must map to a known SPIR-V decoration;
/// otherwise an error is emitted at `loc`.
LogicalResult Serializer::processDecoration(Location loc, uint32_t resultID,
                                            NamedAttribute attr) {
  StringRef attrName = attr.getName().strref();
  std::optional<Decoration> decoration =
      spirv::symbolizeDecoration(getDecorationName(attrName));
  if (decoration)
    return processDecorationAttr(loc, resultID, *decoration, attr.getValue());

  return emitError(
             loc, "non-argument attributes expected to have snake-case-ified "
                  "decoration name, unhandled attribute with name : ")
         << attrName;
}
/// Emits `OpName resultID "name"` into the names section when symbol names
/// are enabled; otherwise does nothing. `name` must be non-empty.
LogicalResult Serializer::processName(uint32_t resultID, StringRef name) {
  assert(!name.empty() && "unexpected empty string for OpName");
  if (options.emitSymbolName) {
    SmallVector<uint32_t, 4> operandWords = {resultID};
    spirv::encodeStringLiteralInto(operandWords, name);
    encodeInstructionInto(names, spirv::Opcode::OpName, operandWords);
  }
  return success();
}
/// Emits an ArrayStride decoration for `type` only when it carries a non-zero
/// stride.
template <>
LogicalResult Serializer::processTypeDecoration<spirv::ArrayType>(
    Location loc, spirv::ArrayType type, uint32_t resultID) {
  unsigned stride = type.getArrayStride();
  if (stride == 0)
    return success();
  return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
}
/// Emits an ArrayStride decoration for `type` only when it carries a non-zero
/// stride.
template <>
LogicalResult Serializer::processTypeDecoration<spirv::RuntimeArrayType>(
    Location loc, spirv::RuntimeArrayType type, uint32_t resultID) {
  unsigned stride = type.getArrayStride();
  if (stride == 0)
    return success();
  return emitDecoration(resultID, spirv::Decoration::ArrayStride, {stride});
}
/// Emits `OpMemberDecorate structID memberIndex decoration [value]`. The
/// trailing value operand is present only when `memberDecoration.hasValue`.
LogicalResult Serializer::processMemberDecoration(
    uint32_t structID,
    const spirv::StructType::MemberDecorationInfo &memberDecoration) {
  SmallVector<uint32_t, 4> operandWords;
  operandWords.push_back(structID);
  operandWords.push_back(memberDecoration.memberIndex);
  operandWords.push_back(static_cast<uint32_t>(memberDecoration.decoration));
  if (memberDecoration.hasValue)
    operandWords.push_back(memberDecoration.decorationValue);

  encodeInstructionInto(decorations, spirv::Opcode::OpMemberDecorate,
                        operandWords);
  return success();
}
/// Returns true if `type` is a pointer to a struct in the
/// PhysicalStorageBuffer, PushConstant, StorageBuffer or Uniform storage
/// class.
bool Serializer::isInterfaceStructPtrType(Type type) const {
  auto ptrType = dyn_cast<spirv::PointerType>(type);
  if (!ptrType)
    return false;

  spirv::StorageClass storage = ptrType.getStorageClass();
  bool isInterfaceStorage =
      storage == spirv::StorageClass::PhysicalStorageBuffer ||
      storage == spirv::StorageClass::PushConstant ||
      storage == spirv::StorageClass::StorageBuffer ||
      storage == spirv::StorageClass::Uniform;
  return isInterfaceStorage &&
         isa<spirv::StructType>(ptrType.getPointeeType());
}
/// Public entry point for type serialization. Starts with an empty set of
/// identified structs currently being serialized; that set lets
/// processTypeImpl detect recursive struct references.
LogicalResult Serializer::processType(Location loc, Type type,
                                      uint32_t &typeID) {
  SetVector<StringRef> structsInProgress;
  return processTypeImpl(loc, type, typeID, structsInProgress);
}
/// Serializes `type` (if not already serialized) and sets `typeID` to its
/// <id>.
///
/// Function types are prepared via prepareFunctionType first. If that fails,
/// or the type is not a function type, prepareBasicType is tried, and it
/// reports unhandled types. When prepareBasicType defers the type (a pointer
/// to a struct still being serialized), nothing is emitted or recorded here.
/// After emitting `type`, any OpTypePointer instructions that were
/// forward-declared against it (see recursiveStructInfos) are emitted.
LogicalResult
Serializer::processTypeImpl(Location loc, Type type, uint32_t &typeID,
                            SetVector<StringRef> &serializationCtx) {
  typeID = getTypeID(type);
  if (typeID != 0)
    return success();

  typeID = getNextID();
  SmallVector<uint32_t, 4> operands = {typeID};
  spirv::Opcode typeEnum = spirv::Opcode::OpTypeVoid;
  bool deferSerialization = false;

  bool prepared = false;
  if (auto fnType = dyn_cast<FunctionType>(type))
    prepared = succeeded(prepareFunctionType(loc, fnType, typeEnum, operands));
  if (!prepared)
    prepared = succeeded(prepareBasicType(loc, type, typeID, typeEnum,
                                          operands, deferSerialization,
                                          serializationCtx));
  if (!prepared)
    return failure();

  if (deferSerialization)
    return success();

  typeIDMap[type] = typeID;
  encodeInstructionInto(typesGlobalValues, typeEnum, operands);

  // Emit the pointer types that were forward-declared while this struct was
  // being serialized, now that it has an <id>.
  auto pending = recursiveStructInfos.find(type);
  if (pending != recursiveStructInfos.end()) {
    for (auto &ptrInfo : pending->second) {
      SmallVector<uint32_t, 4> ptrOperands = {
          ptrInfo.pointerTypeID, static_cast<uint32_t>(ptrInfo.storageClass),
          typeID};
      encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpTypePointer,
                            ptrOperands);
    }
    pending->second.clear();
  }
  return success();
}
/// Prepares the opcode (`typeEnum`) and operands (appended to `operands`) of
/// the type-declaration instruction for a non-function `type` whose <id> is
/// `resultID`. Nested types are serialized on demand.
///
/// `serializationCtx` holds the identifiers of structs currently being
/// serialized. A pointer to such a struct is emitted as OpTypeForwardPointer,
/// and `deferSerialization` is set so that the caller does not emit the
/// OpTypePointer now. It is emitted once the struct is done.
///
/// Returns failure (with a diagnostic for unhandled types) if any nested
/// type or constant cannot be serialized.
LogicalResult Serializer::prepareBasicType(
    Location loc, Type type, uint32_t resultID, spirv::Opcode &typeEnum,
    SmallVectorImpl<uint32_t> &operands, bool &deferSerialization,
    SetVector<StringRef> &serializationCtx) {
  deferSerialization = false;

  if (isVoidType(type)) {
    typeEnum = spirv::Opcode::OpTypeVoid;
    return success();
  }

  if (auto intType = dyn_cast<IntegerType>(type)) {
    // i1 is SPIR-V's boolean type.
    if (intType.getWidth() == 1) {
      typeEnum = spirv::Opcode::OpTypeBool;
      return success();
    }
    typeEnum = spirv::Opcode::OpTypeInt;
    operands.push_back(intType.getWidth());
    operands.push_back(intType.isSigned() ? 1 : 0);
    return success();
  }

  if (auto floatType = dyn_cast<FloatType>(type)) {
    typeEnum = spirv::Opcode::OpTypeFloat;
    operands.push_back(floatType.getWidth());
    return success();
  }

  if (auto vectorType = dyn_cast<VectorType>(type)) {
    uint32_t elementTypeID = 0;
    if (failed(processTypeImpl(loc, vectorType.getElementType(), elementTypeID,
                               serializationCtx))) {
      return failure();
    }
    typeEnum = spirv::Opcode::OpTypeVector;
    operands.push_back(elementTypeID);
    operands.push_back(vectorType.getNumElements());
    return success();
  }

  if (auto imageType = dyn_cast<spirv::ImageType>(type)) {
    typeEnum = spirv::Opcode::OpTypeImage;
    uint32_t sampledTypeID = 0;
    if (failed(processType(loc, imageType.getElementType(), sampledTypeID)))
      return failure();

    llvm::append_values(operands, sampledTypeID,
                        static_cast<uint32_t>(imageType.getDim()),
                        static_cast<uint32_t>(imageType.getDepthInfo()),
                        static_cast<uint32_t>(imageType.getArrayedInfo()),
                        static_cast<uint32_t>(imageType.getSamplingInfo()),
                        static_cast<uint32_t>(imageType.getSamplerUseInfo()),
                        static_cast<uint32_t>(imageType.getImageFormat()));
    return success();
  }

  if (auto arrayType = dyn_cast<spirv::ArrayType>(type)) {
    typeEnum = spirv::Opcode::OpTypeArray;
    uint32_t elementTypeID = 0;
    if (failed(processTypeImpl(loc, arrayType.getElementType(), elementTypeID,
                               serializationCtx))) {
      return failure();
    }
    operands.push_back(elementTypeID);
    // The length operand of OpTypeArray is mandatory: fail instead of
    // silently emitting a malformed instruction without it.
    uint32_t elementCountID = prepareConstantInt(
        loc, mlirBuilder.getI32IntegerAttr(arrayType.getNumElements()));
    if (!elementCountID)
      return failure();
    operands.push_back(elementCountID);
    return processTypeDecoration(loc, arrayType, resultID);
  }

  if (auto ptrType = dyn_cast<spirv::PointerType>(type)) {
    uint32_t pointeeTypeID = 0;
    spirv::StructType pointeeStruct =
        dyn_cast<spirv::StructType>(ptrType.getPointeeType());

    if (pointeeStruct && pointeeStruct.isIdentified() &&
        serializationCtx.count(pointeeStruct.getIdentifier()) != 0) {
      // The pointee struct is still being serialized (recursive reference):
      // forward-declare the pointer and defer the OpTypePointer until the
      // struct has been emitted.
      SmallVector<uint32_t, 2> forwardPtrOperands;
      forwardPtrOperands.push_back(resultID);
      forwardPtrOperands.push_back(
          static_cast<uint32_t>(ptrType.getStorageClass()));

      encodeInstructionInto(typesGlobalValues,
                            spirv::Opcode::OpTypeForwardPointer,
                            forwardPtrOperands);

      auto structType = spirv::StructType::getIdentified(
          module.getContext(), pointeeStruct.getIdentifier());

      if (!structType)
        return failure();

      deferSerialization = true;
      recursiveStructInfos[structType].push_back(
          {resultID, ptrType.getStorageClass()});
    } else {
      if (failed(processTypeImpl(loc, ptrType.getPointeeType(), pointeeTypeID,
                                 serializationCtx)))
        return failure();
    }

    typeEnum = spirv::Opcode::OpTypePointer;
    operands.push_back(static_cast<uint32_t>(ptrType.getStorageClass()));
    operands.push_back(pointeeTypeID);

    if (isInterfaceStructPtrType(ptrType)) {
      if (failed(emitDecoration(getTypeID(pointeeStruct),
                                spirv::Decoration::Block)))
        return emitError(loc, "cannot decorate ")
               << pointeeStruct << " with Block decoration";
    }

    return success();
  }

  if (auto runtimeArrayType = dyn_cast<spirv::RuntimeArrayType>(type)) {
    uint32_t elementTypeID = 0;
    if (failed(processTypeImpl(loc, runtimeArrayType.getElementType(),
                               elementTypeID, serializationCtx))) {
      return failure();
    }
    typeEnum = spirv::Opcode::OpTypeRuntimeArray;
    operands.push_back(elementTypeID);
    return processTypeDecoration(loc, runtimeArrayType, resultID);
  }

  if (auto sampledImageType = dyn_cast<spirv::SampledImageType>(type)) {
    typeEnum = spirv::Opcode::OpTypeSampledImage;
    uint32_t imageTypeID = 0;
    if (failed(
            processType(loc, sampledImageType.getImageType(), imageTypeID))) {
      return failure();
    }
    operands.push_back(imageTypeID);
    return success();
  }

  if (auto structType = dyn_cast<spirv::StructType>(type)) {
    // Identified structs are named and tracked in `serializationCtx` while
    // their members are serialized, so that pointers back to them can be
    // forward-declared.
    if (structType.isIdentified()) {
      if (failed(processName(resultID, structType.getIdentifier())))
        return failure();
      serializationCtx.insert(structType.getIdentifier());
    }

    bool hasOffset = structType.hasOffset();
    for (auto elementIndex :
         llvm::seq<uint32_t>(0, structType.getNumElements())) {
      uint32_t elementTypeID = 0;
      if (failed(processTypeImpl(loc, structType.getElementType(elementIndex),
                                 elementTypeID, serializationCtx))) {
        return failure();
      }
      operands.push_back(elementTypeID);
      if (hasOffset) {
        // Decorate each struct member with an offset.
        spirv::StructType::MemberDecorationInfo offsetDecoration{
            elementIndex, 1, spirv::Decoration::Offset,
            static_cast<uint32_t>(structType.getMemberOffset(elementIndex))};
        if (failed(processMemberDecoration(resultID, offsetDecoration))) {
          return emitError(loc, "cannot decorate ")
                 << elementIndex << "-th member of " << structType
                 << " with its offset";
        }
      }
    }

    SmallVector<spirv::StructType::MemberDecorationInfo, 4> memberDecorations;
    structType.getMemberDecorations(memberDecorations);

    for (auto &memberDecoration : memberDecorations) {
      if (failed(processMemberDecoration(resultID, memberDecoration))) {
        return emitError(loc, "cannot decorate ")
               << static_cast<uint32_t>(memberDecoration.memberIndex)
               << "-th member of " << structType << " with "
               << stringifyDecoration(memberDecoration.decoration);
      }
    }

    typeEnum = spirv::Opcode::OpTypeStruct;

    if (structType.isIdentified())
      serializationCtx.remove(structType.getIdentifier());

    return success();
  }

  if (auto cooperativeMatrixType =
          dyn_cast<spirv::CooperativeMatrixType>(type)) {
    uint32_t elementTypeID = 0;
    if (failed(processTypeImpl(loc, cooperativeMatrixType.getElementType(),
                               elementTypeID, serializationCtx))) {
      return failure();
    }
    typeEnum = spirv::Opcode::OpTypeCooperativeMatrixKHR;
    auto getConstantOp = [&](uint32_t id) {
      auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
      return prepareConstantInt(loc, attr);
    };
    // Create the constants in separate statements. Each call may allocate an
    // <id> and append an instruction, and function arguments are evaluated
    // in an unspecified order. Computing them inside a single call would make
    // the binary compiler-dependent.
    uint32_t scopeID =
        getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getScope()));
    uint32_t rowsID = getConstantOp(cooperativeMatrixType.getRows());
    uint32_t columnsID = getConstantOp(cooperativeMatrixType.getColumns());
    uint32_t useID =
        getConstantOp(static_cast<uint32_t>(cooperativeMatrixType.getUse()));
    if (!scopeID || !rowsID || !columnsID || !useID)
      return failure();
    llvm::append_values(operands, elementTypeID, scopeID, rowsID, columnsID,
                        useID);
    return success();
  }

  if (auto jointMatrixType = dyn_cast<spirv::JointMatrixINTELType>(type)) {
    uint32_t elementTypeID = 0;
    if (failed(processTypeImpl(loc, jointMatrixType.getElementType(),
                               elementTypeID, serializationCtx))) {
      return failure();
    }
    typeEnum = spirv::Opcode::OpTypeJointMatrixINTEL;
    auto getConstantOp = [&](uint32_t id) {
      auto attr = IntegerAttr::get(IntegerType::get(type.getContext(), 32), id);
      return prepareConstantInt(loc, attr);
    };
    // Sequenced explicitly for a deterministic <id>/instruction order (see
    // the cooperative matrix case above).
    uint32_t rowsID = getConstantOp(jointMatrixType.getRows());
    uint32_t columnsID = getConstantOp(jointMatrixType.getColumns());
    uint32_t layoutID =
        getConstantOp(static_cast<uint32_t>(jointMatrixType.getMatrixLayout()));
    uint32_t scopeID =
        getConstantOp(static_cast<uint32_t>(jointMatrixType.getScope()));
    if (!rowsID || !columnsID || !layoutID || !scopeID)
      return failure();
    llvm::append_values(operands, elementTypeID, rowsID, columnsID, layoutID,
                        scopeID);
    return success();
  }

  if (auto matrixType = dyn_cast<spirv::MatrixType>(type)) {
    uint32_t elementTypeID = 0;
    if (failed(processTypeImpl(loc, matrixType.getColumnType(), elementTypeID,
                               serializationCtx))) {
      return failure();
    }
    typeEnum = spirv::Opcode::OpTypeMatrix;
    llvm::append_values(operands, elementTypeID, matrixType.getNumColumns());
    return success();
  }

  // TODO: Handle other types.
  return emitError(loc, "unhandled type in serialization: ") << type;
}
/// Prepares an OpTypeFunction for `type`. The operands are the return type
/// <id> (void when there is no result), then one <id> per parameter type.
/// At most one result is supported.
LogicalResult
Serializer::prepareFunctionType(Location loc, FunctionType type,
                                spirv::Opcode &typeEnum,
                                SmallVectorImpl<uint32_t> &operands) {
  typeEnum = spirv::Opcode::OpTypeFunction;
  assert(type.getNumResults() <= 1 &&
         "serialization supports only a single return value");

  Type returnType =
      type.getNumResults() == 1 ? type.getResult(0) : getVoidType();
  uint32_t returnTypeID = 0;
  if (failed(processType(loc, returnType, returnTypeID)))
    return failure();
  operands.push_back(returnTypeID);

  for (Type paramType : type.getInputs()) {
    uint32_t paramTypeID = 0;
    if (failed(processType(loc, paramType, paramTypeID)))
      return failure();
    operands.push_back(paramTypeID);
  }
  return success();
}
/// Returns the <id> of a (non-spec) constant of type `constType` with value
/// `valueAttr`, emitting it if needed.
///
/// Scalars are delegated to prepareConstantScalar. Dense elements and array
/// attributes become (possibly nested) OpConstantComposite instructions, and
/// composite results are cached in constIDMap. Returns 0 and emits an error
/// if the attribute cannot be serialized.
uint32_t Serializer::prepareConstant(Location loc, Type constType,
                                     Attribute valueAttr) {
  if (auto id = prepareConstantScalar(loc, valueAttr)) {
    return id;
  }

  // Reuse a previously serialized composite constant.
  if (auto id = getConstantID(valueAttr)) {
    return id;
  }

  // Make sure the constant's type is serializable before emitting anything.
  uint32_t typeID = 0;
  if (failed(processType(loc, constType, typeID))) {
    return 0;
  }

  uint32_t resultID = 0;
  if (auto attr = dyn_cast<DenseElementsAttr>(valueAttr)) {
    // The type of dense elements is always shaped; use `cast` so a violation
    // asserts instead of dereferencing a null ShapedType.
    int rank = cast<ShapedType>(attr.getType()).getRank();
    // Scratch multi-dimensional index filled in during the recursive walk.
    SmallVector<uint64_t, 4> index(rank);
    resultID = prepareDenseElementsConstant(loc, constType, attr,
                                            /*dim=*/0, index);
  } else if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
    resultID = prepareArrayConstant(loc, constType, arrayAttr);
  }

  if (resultID == 0) {
    emitError(loc, "cannot serialize attribute: ") << valueAttr;
    return 0;
  }

  constIDMap[valueAttr] = resultID;
  return resultID;
}
/// Emits an OpConstantComposite of array type `constType` whose constituents
/// are the elements of `attr` (each serialized via prepareConstant). Returns
/// the composite's <id>, or 0 if the type or any element fails.
uint32_t Serializer::prepareArrayConstant(Location loc, Type constType,
                                          ArrayAttr attr) {
  uint32_t typeID = 0;
  if (failed(processType(loc, constType, typeID)))
    return 0;

  uint32_t resultID = getNextID();
  SmallVector<uint32_t, 4> compositeOperands;
  compositeOperands.reserve(attr.size() + 2);
  compositeOperands.push_back(typeID);
  compositeOperands.push_back(resultID);

  Type elementType = cast<spirv::ArrayType>(constType).getElementType();
  for (Attribute elementAttr : attr) {
    uint32_t elementID = prepareConstant(loc, elementType, elementAttr);
    if (elementID == 0)
      return 0;
    compositeOperands.push_back(elementID);
  }

  encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpConstantComposite,
                        compositeOperands);
  return resultID;
}
/// Recursively serializes the dense constant `valueAttr`, starting at
/// dimension `dim`. `index` holds the positions already fixed for dimensions
/// < `dim`.
///
/// At full depth (`dim == rank`) this returns the <id> of the scalar element
/// at `index` (bool, int or float). Otherwise it emits an
/// OpConstantComposite of `constType` whose constituents are the
/// sub-constants along dimension `dim`. Returns 0 on failure.
uint32_t
Serializer::prepareDenseElementsConstant(Location loc, Type constType,
                                         DenseElementsAttr valueAttr, int dim,
                                         MutableArrayRef<uint64_t> index) {
  // Dense elements always have a shaped type; `cast` asserts on violation
  // instead of silently producing a null ShapedType.
  auto shapedType = cast<ShapedType>(valueAttr.getType());
  assert(dim <= shapedType.getRank());
  if (shapedType.getRank() == dim) {
    if (auto attr = dyn_cast<DenseIntElementsAttr>(valueAttr)) {
      return attr.getType().getElementType().isInteger(1)
                 ? prepareConstantBool(loc, attr.getValues<BoolAttr>()[index])
                 : prepareConstantInt(loc,
                                      attr.getValues<IntegerAttr>()[index]);
    }
    if (auto attr = dyn_cast<DenseFPElementsAttr>(valueAttr)) {
      return prepareConstantFp(loc, attr.getValues<FloatAttr>()[index]);
    }
    return 0;
  }

  uint32_t typeID = 0;
  if (failed(processType(loc, constType, typeID))) {
    return 0;
  }

  const int64_t dimSize = shapedType.getDimSize(dim);
  uint32_t resultID = getNextID();
  SmallVector<uint32_t, 4> operands = {typeID, resultID};
  operands.reserve(dimSize + 2);
  auto elementType = cast<spirv::CompositeType>(constType).getElementType(0);
  // Use int64_t to match getDimSize() and avoid a signed-width mismatch.
  for (int64_t i = 0; i < dimSize; ++i) {
    index[dim] = i;
    if (auto elementID = prepareDenseElementsConstant(
            loc, elementType, valueAttr, dim + 1, index)) {
      operands.push_back(elementID);
    } else {
      return 0;
    }
  }
  spirv::Opcode opcode = spirv::Opcode::OpConstantComposite;
  encodeInstructionInto(typesGlobalValues, opcode, operands);

  return resultID;
}
/// Serializes a scalar constant attribute (float, bool or integer) and
/// returns its <id>, or 0 if `valueAttr` is not a scalar constant. BoolAttr
/// is matched before IntegerAttr because a BoolAttr is also an IntegerAttr.
uint32_t Serializer::prepareConstantScalar(Location loc, Attribute valueAttr,
                                           bool isSpec) {
  return TypeSwitch<Attribute, uint32_t>(valueAttr)
      .Case([&](FloatAttr floatAttr) {
        return prepareConstantFp(loc, floatAttr, isSpec);
      })
      .Case([&](BoolAttr boolAttr) {
        return prepareConstantBool(loc, boolAttr, isSpec);
      })
      .Case([&](IntegerAttr intAttr) {
        return prepareConstantInt(loc, intAttr, isSpec);
      })
      .Default([](Attribute) { return uint32_t(0); });
}
/// Emits Op(Spec)ConstantTrue/False for `boolAttr` and returns its <id>, or 0
/// if the type cannot be serialized. Non-spec constants are uniqued through
/// constIDMap; spec constants are always emitted fresh and never cached.
uint32_t Serializer::prepareConstantBool(Location loc, BoolAttr boolAttr,
                                         bool isSpec) {
  if (!isSpec) {
    if (uint32_t cachedID = getConstantID(boolAttr))
      return cachedID;
  }

  uint32_t typeID = 0;
  if (failed(processType(loc, cast<IntegerAttr>(boolAttr).getType(), typeID)))
    return 0;

  uint32_t resultID = getNextID();
  spirv::Opcode opcode;
  if (boolAttr.getValue())
    opcode = isSpec ? spirv::Opcode::OpSpecConstantTrue
                    : spirv::Opcode::OpConstantTrue;
  else
    opcode = isSpec ? spirv::Opcode::OpSpecConstantFalse
                    : spirv::Opcode::OpConstantFalse;
  encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID});

  if (!isSpec)
    constIDMap[boolAttr] = resultID;
  return resultID;
}
/// Emits Op(Spec)Constant for the integer `intAttr` and returns its <id>, or
/// 0 on failure (unserializable type, or a bit width other than 8/16/32/64).
/// Non-spec constants are uniqued through constIDMap.
uint32_t Serializer::prepareConstantInt(Location loc, IntegerAttr intAttr,
                                        bool isSpec) {
  if (!isSpec) {
    if (auto id = getConstantID(intAttr)) {
      return id;
    }
  }

  uint32_t typeID = 0;
  if (failed(processType(loc, intAttr.getType(), typeID))) {
    return 0;
  }

  auto resultID = getNextID();
  APInt value = intAttr.getValue();
  unsigned bitwidth = value.getBitWidth();
  bool isSigned = intAttr.getType().isSignedInteger();
  auto opcode =
      isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;

  switch (bitwidth) {
  case 32:
  case 16:
  case 8: {
    // Literals of at most 32 bits occupy one word: sign-extended for signed
    // types, zero-extended otherwise.
    uint32_t word = 0;
    if (isSigned) {
      word = static_cast<int32_t>(value.getSExtValue());
    } else {
      word = static_cast<uint32_t>(value.getZExtValue());
    }
    encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
  } break;
  case 64: {
    // SPIR-V stores multi-word literals low-order word first. Split the value
    // arithmetically instead of reinterpreting host memory, which was only
    // correct on little-endian hosts. For a 64-bit APInt, sign- and
    // zero-extension yield the same 64 bits, so the raw bits are used.
    uint64_t bits = value.getZExtValue();
    uint32_t lowWord = static_cast<uint32_t>(bits);
    uint32_t highWord = static_cast<uint32_t>(bits >> 32);
    encodeInstructionInto(typesGlobalValues, opcode,
                          {typeID, resultID, lowWord, highWord});
  } break;
  default: {
    std::string valueStr;
    llvm::raw_string_ostream rss(valueStr);
    value.print(rss, /*isSigned=*/false);

    emitError(loc, "cannot serialize ")
        << bitwidth << "-bit integer literal: " << rss.str();
    return 0;
  }
  }

  if (!isSpec) {
    constIDMap[intAttr] = resultID;
  }
  return resultID;
}
/// Emits Op(Spec)Constant for the float `floatAttr` (f16, f32 or f64) and
/// returns its <id>, or 0 on failure (unserializable type or unsupported
/// float semantics). Non-spec constants are uniqued through constIDMap.
uint32_t Serializer::prepareConstantFp(Location loc, FloatAttr floatAttr,
                                       bool isSpec) {
  if (!isSpec) {
    if (auto id = getConstantID(floatAttr)) {
      return id;
    }
  }

  uint32_t typeID = 0;
  if (failed(processType(loc, floatAttr.getType(), typeID))) {
    return 0;
  }

  auto resultID = getNextID();
  APFloat value = floatAttr.getValue();
  // Encode from the literal's exact IEEE bit pattern. Converting through a
  // host float/double and back is unnecessary and may not preserve every bit
  // pattern on every ABI (e.g. NaN payloads).
  APInt intValue = value.bitcastToAPInt();

  auto opcode =
      isSpec ? spirv::Opcode::OpSpecConstant : spirv::Opcode::OpConstant;

  if (&value.getSemantics() == &APFloat::IEEEsingle()) {
    uint32_t word = static_cast<uint32_t>(intValue.getZExtValue());
    encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
  } else if (&value.getSemantics() == &APFloat::IEEEdouble()) {
    // SPIR-V stores multi-word literals low-order word first. Split
    // arithmetically so the result is independent of host endianness.
    uint64_t bits = intValue.getZExtValue();
    uint32_t lowWord = static_cast<uint32_t>(bits);
    uint32_t highWord = static_cast<uint32_t>(bits >> 32);
    encodeInstructionInto(typesGlobalValues, opcode,
                          {typeID, resultID, lowWord, highWord});
  } else if (&value.getSemantics() == &APFloat::IEEEhalf()) {
    // 16-bit literals are zero-extended into a single word.
    uint32_t word = static_cast<uint32_t>(intValue.getZExtValue());
    encodeInstructionInto(typesGlobalValues, opcode, {typeID, resultID, word});
  } else {
    std::string valueStr;
    llvm::raw_string_ostream rss(valueStr);
    value.print(rss);

    emitError(loc, "cannot serialize ")
        << floatAttr.getType() << "-typed float literal: " << rss.str();
    return 0;
  }

  if (!isSpec) {
    constIDMap[floatAttr] = resultID;
  }
  return resultID;
}
/// Returns the <id> of `block`, allocating and recording a fresh one on first
/// request.
uint32_t Serializer::getOrCreateBlockID(Block *block) {
  uint32_t blockID = getBlockID(block);
  if (blockID == 0) {
    blockID = getNextID();
    blockIDMap[block] = blockID;
  }
  return blockID;
}
#ifndef NDEBUG
/// Debug helper: prints `block` and its <id> (or "unknown" if it has none
/// yet) on one line.
void Serializer::printBlock(Block *block, raw_ostream &os) {
  os << "block " << block << " (id = ";
  uint32_t blockID = getBlockID(block);
  if (blockID != 0)
    os << blockID;
  else
    os << "unknown";
  os << ")\n";
}
#endif
/// Serializes `block` into `functionBody`.
///
/// Unless `omitLabel` is set, an OpLabel carrying the block's <id> is emitted
/// first. OpPhi instructions for the block arguments follow, then the block's
/// ops, with the terminator last.
///
/// `emitMerge`, when provided, emits the merge instruction of the enclosing
/// structured op. Normally it runs right before the terminator. If the block
/// itself contains a nested spirv::LoopOp/SelectionOp, the merge is emitted
/// immediately instead, followed by an OpBranch to a fresh label. The nested
/// ops are therefore serialized after that new label rather than in the block
/// holding the merge instruction.
LogicalResult
Serializer::processBlock(Block *block, bool omitLabel,
                         function_ref<LogicalResult()> emitMerge) {
  LLVM_DEBUG(llvm::dbgs() << "processing block " << block << ":\n");
  LLVM_DEBUG(block->print(llvm::dbgs()));
  LLVM_DEBUG(llvm::dbgs() << '\n');
  if (!omitLabel) {
    uint32_t blockID = getOrCreateBlockID(block);
    LLVM_DEBUG(printBlock(block, llvm::dbgs()));
    encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
  }

  // Phis come immediately after the label.
  if (failed(emitPhiForBlockArguments(block)))
    return failure();

  // Nested structured control flow: emit the merge now and start a fresh
  // (unnamed in MLIR) block via OpBranch + OpLabel with a new <id>. Clearing
  // `emitMerge` keeps it from running a second time below.
  if (emitMerge &&
      llvm::any_of(block->getOperations(),
                   llvm::IsaPred<spirv::LoopOp, spirv::SelectionOp>)) {
    if (failed(emitMerge()))
      return failure();
    emitMerge = nullptr;

    uint32_t blockID = getNextID();
    encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {blockID});
    encodeInstructionInto(functionBody, spirv::Opcode::OpLabel, {blockID});
  }

  // Every op except the terminator.
  for (Operation &op : llvm::drop_end(*block)) {
    if (failed(processOperation(&op)))
      return failure();
  }

  // Otherwise the merge instruction goes right before the terminator.
  if (emitMerge)
    if (failed(emitMerge()))
      return failure();

  if (failed(processOperation(&block->back())))
    return failure();

  return success();
}
/// Emits one OpPhi into `functionBody` per argument of `block`. Called from
/// processBlock() right after the block label.
///
/// Entry blocks and blocks without arguments are skipped. For each MLIR
/// predecessor, the incoming values come from its terminator's successor
/// operands. Only spirv::BranchOp and spirv::BranchConditionalOp are
/// supported; any other terminator is reported as an error. The incoming
/// SPIR-V block is computed with getPhiIncomingBlock(), because structured
/// control flow can change which SPIR-V block branches here.
///
/// An incoming value that has no <id> yet is encoded as 0, and its word
/// offset in `functionBody` is recorded in `deferredPhiValues`. Presumably
/// it is patched once the value is serialized (not visible here).
LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
  if (block->args_empty() || block->isEntryBlock())
    return success();

  LLVM_DEBUG(llvm::dbgs() << "emitting phi instructions..\n");

  // (SPIR-V incoming block, values passed for `block`'s arguments) per
  // predecessor.
  SmallVector<std::pair<Block *, OperandRange>, 4> predecessors;
  for (Block *mlirPredecessor : block->getPredecessors()) {
    auto *terminator = mlirPredecessor->getTerminator();
    LLVM_DEBUG(llvm::dbgs() << "  mlir predecessor ");
    LLVM_DEBUG(printBlock(mlirPredecessor, llvm::dbgs()));
    LLVM_DEBUG(llvm::dbgs() << "    terminator: " << *terminator << "\n");
    Block *spirvPredecessor = getPhiIncomingBlock(mlirPredecessor);
    LLVM_DEBUG(llvm::dbgs() << "    spirv predecessor ");
    LLVM_DEBUG(printBlock(spirvPredecessor, llvm::dbgs()));
    if (auto branchOp = dyn_cast<spirv::BranchOp>(terminator)) {
      predecessors.emplace_back(spirvPredecessor, branchOp.getOperands());
    } else if (auto branchCondOp =
                   dyn_cast<spirv::BranchConditionalOp>(terminator)) {
      // Take the operand list of whichever edge targets `block`.
      std::optional<OperandRange> blockOperands;
      if (branchCondOp.getTrueTarget() == block) {
        blockOperands = branchCondOp.getTrueTargetOperands();
      } else {
        assert(branchCondOp.getFalseTarget() == block);
        blockOperands = branchCondOp.getFalseTargetOperands();
      }

      assert(!blockOperands->empty() &&
             "expected non-empty block operand range");
      predecessors.emplace_back(spirvPredecessor, *blockOperands);
    } else {
      return terminator->emitError("unimplemented terminator for Phi creation");
    }
    LLVM_DEBUG({
      llvm::dbgs() << "    block arguments:\n";
      for (Value v : predecessors.back().second)
        llvm::dbgs() << "      " << v << "\n";
    });
  }

  // One OpPhi per block argument:
  //   OpPhi <type> <result> (<value> <parent block>)...
  for (auto argIndex : llvm::seq<unsigned>(0, block->getNumArguments())) {
    BlockArgument arg = block->getArgument(argIndex);

    uint32_t phiTypeID = 0;
    if (failed(processType(arg.getLoc(), arg.getType(), phiTypeID)))
      return failure();
    uint32_t phiID = getNextID();

    LLVM_DEBUG(llvm::dbgs() << "[phi] for block argument #" << argIndex << ' '
                            << arg << " (id = " << phiID << ")\n");

    SmallVector<uint32_t, 8> phiArgs;
    phiArgs.push_back(phiTypeID);
    phiArgs.push_back(phiID);

    for (auto predIndex : llvm::seq<unsigned>(0, predecessors.size())) {
      Value value = predecessors[predIndex].second[argIndex];
      uint32_t predBlockId = getOrCreateBlockID(predecessors[predIndex].first);
      LLVM_DEBUG(llvm::dbgs() << "[phi] use predecessor (id = " << predBlockId
                              << ") value " << value << ' ');
      uint32_t valueId = getValueID(value);
      if (valueId == 0) {
        LLVM_DEBUG(llvm::dbgs() << "(need to fix)\n");
        // Absolute word offset this operand will occupy in `functionBody`:
        // the current size, +1 for the OpPhi opcode word, + the operands
        // already collected in `phiArgs`.
        deferredPhiValues[value].push_back(functionBody.size() + 1 +
                                           phiArgs.size());
      } else {
        LLVM_DEBUG(llvm::dbgs() << "(id = " << valueId << ")\n");
      }
      phiArgs.push_back(valueId);
      phiArgs.push_back(predBlockId);
    }

    encodeInstructionInto(functionBody, spirv::Opcode::OpPhi, phiArgs);
    // Uses of the block argument now refer to the phi's result.
    valueIDMap[arg] = phiID;
  }

  return success();
}
/// Emits `OpExtInst <result type> <result id> <set> <extensionOpcode>
/// <operands...>` into `functionBody`.
///
/// `operands` must start with the result type <id> and the result <id>. The
/// OpExtInstImport for `extensionSetName` is emitted (and its <id> allocated)
/// the first time the set is used. Returns failure, with an error on `op`, if
/// `operands` has fewer than two words.
LogicalResult Serializer::encodeExtensionInstruction(
    Operation *op, StringRef extensionSetName, uint32_t extensionOpcode,
    ArrayRef<uint32_t> operands) {
  // Validate before touching any state. Previously a malformed op still
  // triggered the OpExtInstImport emission and consumed an <id> before
  // failing.
  if (operands.size() < 2) {
    return op->emitError("extended instructions must have a result encoding");
  }

  // Import the extended instruction set on first use.
  auto &setID = extendedInstSetIDMap[extensionSetName];
  if (!setID) {
    setID = getNextID();
    SmallVector<uint32_t, 16> importOperands;
    importOperands.push_back(setID);
    spirv::encodeStringLiteralInto(importOperands, extensionSetName);
    encodeInstructionInto(extendedSets, spirv::Opcode::OpExtInstImport,
                          importOperands);
  }

  // Insert <set> <extensionOpcode> between the result encoding and the
  // remaining operands.
  SmallVector<uint32_t, 8> extInstOperands;
  extInstOperands.reserve(operands.size() + 2);
  extInstOperands.append(operands.begin(), std::next(operands.begin(), 2));
  extInstOperands.push_back(setID);
  extInstOperands.push_back(extensionOpcode);
  extInstOperands.append(std::next(operands.begin(), 2), operands.end());
  encodeInstructionInto(functionBody, spirv::Opcode::OpExtInst,
                        extInstOperands);
  return success();
}
/// Serializes a single op. Ops that need hand-written handling are dispatched
/// to their dedicated process* method. Every other op goes through
/// dispatchToAutogenSerialization().
LogicalResult Serializer::processOperation(Operation *opInst) {
  LLVM_DEBUG(llvm::dbgs() << "[op] '" << opInst->getName() << "'\n");

  if (auto op = dyn_cast<spirv::AddressOfOp>(opInst))
    return processAddressOfOp(op);
  if (auto op = dyn_cast<spirv::BranchOp>(opInst))
    return processBranchOp(op);
  if (auto op = dyn_cast<spirv::BranchConditionalOp>(opInst))
    return processBranchConditionalOp(op);
  if (auto op = dyn_cast<spirv::ConstantOp>(opInst))
    return processConstantOp(op);
  if (auto op = dyn_cast<spirv::FuncOp>(opInst))
    return processFuncOp(op);
  if (auto op = dyn_cast<spirv::GlobalVariableOp>(opInst))
    return processGlobalVariableOp(op);
  if (auto op = dyn_cast<spirv::LoopOp>(opInst))
    return processLoopOp(op);
  if (auto op = dyn_cast<spirv::ReferenceOfOp>(opInst))
    return processReferenceOfOp(op);
  if (auto op = dyn_cast<spirv::SelectionOp>(opInst))
    return processSelectionOp(op);
  if (auto op = dyn_cast<spirv::SpecConstantOp>(opInst))
    return processSpecConstantOp(op);
  if (auto op = dyn_cast<spirv::SpecConstantCompositeOp>(opInst))
    return processSpecConstantCompositeOp(op);
  if (auto op = dyn_cast<spirv::SpecConstantOperationOp>(opInst))
    return processSpecConstantOperationOp(op);
  if (auto op = dyn_cast<spirv::UndefOp>(opInst))
    return processUndefOp(op);
  if (auto op = dyn_cast<spirv::VariableOp>(opInst))
    return processVariableOp(op);

  return dispatchToAutogenSerialization(opInst);
}
/// Serializes `op` as instruction `opcode` without consulting grammar
/// attributes. The instruction goes either directly into `functionBody` or,
/// when `extInstSet` is non-empty, through OpExtInst of that set.
///
/// For an op with results, the operands start with the type <id> and a fresh
/// <id> of result #0 (recorded in valueIDMap), and every op attribute is then
/// emitted as a decoration on that result. Operand <id>s are looked up with
/// getValueID(). When debug info is enabled, an OpLine may precede the
/// instruction.
LogicalResult Serializer::processOpWithoutGrammarAttr(Operation *op,
                                                      StringRef extInstSet,
                                                      uint32_t opcode) {
  Location loc = op->getLoc();
  const bool hasResult = op->getNumResults() != 0;
  SmallVector<uint32_t, 4> operands;
  uint32_t resultID = 0;

  if (hasResult) {
    Value result = op->getResult(0);
    uint32_t resultTypeID = 0;
    if (failed(processType(loc, result.getType(), resultTypeID)))
      return failure();
    resultID = getNextID();
    operands.append({resultTypeID, resultID});
    valueIDMap[result] = resultID;
  }

  for (Value operand : op->getOperands())
    operands.push_back(getValueID(operand));

  if (failed(emitDebugLine(functionBody, loc)))
    return failure();

  if (!extInstSet.empty()) {
    if (failed(encodeExtensionInstruction(op, extInstSet, opcode, operands)))
      return failure();
  } else {
    encodeInstructionInto(functionBody, static_cast<spirv::Opcode>(opcode),
                          operands);
  }

  if (hasResult) {
    for (NamedAttribute attr : op->getAttrs()) {
      if (failed(processDecoration(loc, resultID, attr)))
        return failure();
    }
  }
  return success();
}
/// Appends `OpDecorate <target> <decoration> <params...>` to the decorations
/// section. Always succeeds.
LogicalResult Serializer::emitDecoration(uint32_t target,
                                         spirv::Decoration decoration,
                                         ArrayRef<uint32_t> params) {
  // Built through the shared encoder. It prefixes the word count (1 opcode
  // word + 2 fixed operands + params), which matches a hand-packed header.
  SmallVector<uint32_t, 4> decorateOperands = {
      target, static_cast<uint32_t>(decoration)};
  decorateOperands.append(params.begin(), params.end());
  encodeInstructionInto(decorations, spirv::Opcode::OpDecorate,
                        decorateOperands);
  return success();
}
/// When debug info is enabled, appends `OpLine <fileID> <line> <column>` for
/// `loc` to `binary` if `loc` is a FileLineColLoc.
///
/// If the previously processed instruction was a merge instruction, no OpLine
/// is emitted this time, and the flag is reset. SPIR-V requires a merge
/// instruction to immediately precede its block's branch instruction.
LogicalResult Serializer::emitDebugLine(SmallVectorImpl<uint32_t> &binary,
                                        Location loc) {
  if (!options.emitDebugInfo)
    return success();

  if (lastProcessedWasMergeInst) {
    lastProcessedWasMergeInst = false;
    return success();
  }

  if (auto fileLoc = dyn_cast<FileLineColLoc>(loc)) {
    encodeInstructionInto(binary, spirv::Opcode::OpLine,
                          {fileID, fileLoc.getLine(), fileLoc.getColumn()});
  }
  return success();
}
}
}