#include "clang/Driver/OffloadBundler.h"
#include "clang/Basic/Cuda.h"
#include "clang/Basic/TargetID.h"
#include "clang/Basic/Version.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/BinaryFormat/Magic.h"
#include "llvm/Object/Archive.h"
#include "llvm/Object/ArchiveWriter.h"
#include "llvm/Object/Binary.h"
#include "llvm/Object/ObjectFile.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Compression.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/EndianStream.h"
#include "llvm/Support/Errc.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorOr.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/MD5.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/Program.h"
#include "llvm/Support/Signals.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/Timer.h"
#include "llvm/Support/WithColor.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/TargetParser/Triple.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <forward_list>
#include <llvm/Support/Process.h>
#include <memory>
#include <set>
#include <string>
#include <system_error>
#include <utility>
using namespace llvm;
using namespace llvm::object;
using namespace clang;
// Timer group shared by the hash, compression and decompression timers used
// by CompressedOffloadBundle; those timers are only started when verbose
// output is requested.
static llvm::TimerGroup
ClangOffloadBundlerTimerGroup("Clang Offload Bundler Timer Group",
"Timer group for clang offload bundler");
// Magic string that starts a binary bundle, prefixes the names of bundle
// sections in object files, and is embedded in the text-bundle marker lines.
#define OFFLOAD_BUNDLER_MAGIC_STR "__CLANG_OFFLOAD_BUNDLE__"
/// Parses a bundle entry ID of the form
///   <offload-kind>-<triple>[-<gpu>[:<features>]]
/// When the last '-'-separated component (before any ':') is a known offload
/// architecture, it starts the target ID, which also keeps the ':' feature
/// suffix. Otherwise the target ID is empty.
///
/// The triple is rebuilt from its components so that it always carries the
/// (possibly empty) environment field, standardizing bundle entry IDs.
OffloadTargetInfo::OffloadTargetInfo(const StringRef Target,
                                     const OffloadBundlerConfig &BC)
    : BundlerConfig(BC) {
  // Everything before the first ':' is "<kind>-<triple>[-<gpu>]".
  StringRef KindTripleAndGPU = Target.split(':').first;
  auto TripleOrGPU = KindTripleAndGPU.rsplit('-');
  const bool HasGPU = clang::StringToOffloadArch(TripleOrGPU.second) !=
                      clang::OffloadArch::UNKNOWN;

  StringRef KindAndTriple = HasGPU ? TripleOrGPU.first : KindTripleAndGPU;
  auto KindTriple = KindAndTriple.split('-');
  this->OffloadKind = KindTriple.first;

  // Enforce the optional environment field to standardize bundles.
  llvm::Triple T(KindTriple.second);
  this->Triple = llvm::Triple(T.getArchName(), T.getVendorName(),
                              T.getOSName(), T.getEnvironmentName());

  if (HasGPU) {
    // The GPU name starts right after the '-' that rsplit() consumed. Slice
    // by position: searching for the GPU text could match an earlier
    // occurrence of the same characters inside the kind or triple.
    this->TargetID = Target.substr(TripleOrGPU.first.size() + 1);
  } else {
    this->TargetID = "";
  }
}
bool OffloadTargetInfo::hasHostKind() const {
return this->OffloadKind == "host";
}
bool OffloadTargetInfo::isOffloadKindValid() const {
return OffloadKind == "host" || OffloadKind == "openmp" ||
OffloadKind == "hip" || OffloadKind == "hipv4";
}
/// Returns true if a code object of this entry's offload kind may serve a
/// target whose offload kind is \p TargetOffloadKind.
///
/// Identical kinds, and the "hip"/"hipv4" pair in either order, always match.
/// A "hip*" kind (case-insensitive prefix) and "openmp" are interchangeable
/// only when BundlerConfig.HipOpenmpCompatible is set.
bool OffloadTargetInfo::isOffloadKindCompatible(
    const StringRef TargetOffloadKind) const {
  if (OffloadKind == TargetOffloadKind)
    return true;

  const bool IsHipPair =
      (OffloadKind == "hip" && TargetOffloadKind == "hipv4") ||
      (OffloadKind == "hipv4" && TargetOffloadKind == "hip");
  if (IsHipPair)
    return true;

  if (!BundlerConfig.HipOpenmpCompatible)
    return false;

  const bool HipServesOpenMP = OffloadKind.starts_with_insensitive("hip") &&
                               TargetOffloadKind == "openmp";
  const bool OpenMPServesHip = OffloadKind == "openmp" &&
                               TargetOffloadKind.starts_with_insensitive("hip");
  return HipServesOpenMP || OpenMPServesHip;
}
/// A triple is usable only if it is non-empty and names a known architecture.
bool OffloadTargetInfo::isTripleValid() const {
  if (Triple.str().empty())
    return false;
  return Triple.getArch() != Triple::UnknownArch;
}
/// Two entries are equal when their offload kinds and target IDs match
/// exactly and their triples are compatible with each other.
bool OffloadTargetInfo::operator==(const OffloadTargetInfo &Target) const {
  if (OffloadKind != Target.OffloadKind || TargetID != Target.TargetID)
    return false;
  return Triple.isCompatibleWith(Target.Triple);
}
std::string OffloadTargetInfo::str() const {
return Twine(OffloadKind + "-" + Triple.str() + "-" + TargetID).str();
}
/// Chooses the file extension for a device entry: ".bc" when the device name
/// contains "gfx", ".cubin" when it contains "sm_", and otherwise the
/// extension of \p BundleFileName.
static StringRef getDeviceFileExtension(StringRef Device,
                                        StringRef BundleFileName) {
  if (Device.contains("gfx"))
    return ".bc";
  return Device.contains("sm_") ? StringRef(".cubin")
                                : sys::path::extension(BundleFileName);
}
/// Builds "<stem of BundleFileName><extension chosen for Device>", where the
/// extension comes from getDeviceFileExtension().
static std::string getDeviceLibraryFileName(StringRef BundleFileName,
                                            StringRef Device) {
  return (Twine(sys::path::stem(BundleFileName)) +
          getDeviceFileExtension(Device, BundleFileName))
      .str();
}
namespace {
/// Abstract interface for one bundled-file format.
///
/// Reading: ReadHeader() is called first. ReadBundleStart() is then called
/// repeatedly and yields the next bundle ID, or std::nullopt when no bundles
/// remain. ReadBundle() emits the current bundle's payload, and
/// ReadBundleEnd() consumes its trailer.
///
/// Writing: WriteHeader() receives all inputs; the per-bundle Write* hooks
/// and finalizeOutputFile() produce the output (the exact call sequence is
/// driven by the caller).
class FileHandler {
public:
  /// Information about one bundle discovered while iterating a file.
  struct BundleInfo {
    StringRef BundleID;
  };

  FileHandler() = default;
  virtual ~FileHandler() = default;

  /// Parses any file-level header of \p Input.
  virtual Error ReadHeader(MemoryBuffer &Input) = 0;

  /// Advances to the next bundle and returns its ID, or std::nullopt at the
  /// end of the input.
  virtual Expected<std::optional<StringRef>>
  ReadBundleStart(MemoryBuffer &Input) = 0;

  /// Consumes whatever terminates the current bundle.
  virtual Error ReadBundleEnd(MemoryBuffer &Input) = 0;

  /// Writes the payload of the current bundle to \p OS.
  virtual Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) = 0;

  /// Writes the file-level header for the given \p Inputs.
  virtual Error WriteHeader(raw_ostream &OS,
                            ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs) = 0;

  /// Writes whatever opens the bundle for \p TargetTriple.
  virtual Error WriteBundleStart(raw_ostream &OS, StringRef TargetTriple) = 0;

  /// Writes whatever closes the bundle for \p TargetTriple.
  virtual Error WriteBundleEnd(raw_ostream &OS, StringRef TargetTriple) = 0;

  /// Writes the payload \p Input of the current bundle.
  virtual Error WriteBundle(raw_ostream &OS, MemoryBuffer &Input) = 0;

  /// Hook for work that can only happen once all bundles are written.
  virtual Error finalizeOutputFile() { return Error::success(); }

  /// Prints every bundle ID in \p Input to stdout, one per line.
  virtual Error listBundleIDs(MemoryBuffer &Input) {
    if (Error Err = ReadHeader(Input))
      return Err;
    auto PrintAndSkip = [&](const BundleInfo &Info) -> Error {
      llvm::outs() << Info.BundleID << '\n';
      return listBundleIDsCallback(Input, Info);
    };
    return forEachBundle(Input, PrintAndSkip);
  }

  /// Collects every bundle ID in \p Input into \p BundleIds.
  virtual Error getBundleIDs(MemoryBuffer &Input,
                             std::set<StringRef> &BundleIds) {
    if (Error Err = ReadHeader(Input))
      return Err;
    return forEachBundle(Input, [&](const BundleInfo &Info) -> Error {
      BundleIds.insert(Info.BundleID);
      return listBundleIDsCallback(Input, Info);
    });
  }

  /// Calls \p Func for each remaining bundle, stopping at the first error.
  Error forEachBundle(MemoryBuffer &Input,
                      std::function<Error(const BundleInfo &)> Func) {
    for (;;) {
      Expected<std::optional<StringRef>> NextID = ReadBundleStart(Input);
      if (!NextID)
        return NextID.takeError();
      if (!NextID->has_value())
        return Error::success();
      assert(!(*NextID)->empty());
      if (Error Err = Func(BundleInfo{**NextID}))
        return Err;
    }
  }

protected:
  /// Invoked after each bundle ID is reported while listing; formats whose
  /// reader must skip over the bundle body override this.
  virtual Error listBundleIDsCallback(MemoryBuffer &Input,
                                      const BundleInfo &Info) {
    return Error::success();
  }
};
/// Decodes the little-endian 64-bit integer stored at byte offset \p pos of
/// \p Buffer. Callers must guarantee that 8 bytes are available at \p pos.
static uint64_t Read8byteIntegerFromBuffer(StringRef Buffer, size_t pos) {
  const char *Src = Buffer.data() + pos;
  return llvm::support::endian::read64le(Src);
}
/// Appends \p Val to \p OS as 8 little-endian bytes.
static void Write8byteIntegerToBuffer(raw_ostream &OS, uint64_t Val) {
  constexpr llvm::endianness ByteOrder = llvm::endianness::little;
  llvm::support::endian::write(OS, Val, ByteOrder);
}
class BinaryFileHandler final : public FileHandler {
struct BinaryBundleInfo final : public BundleInfo {
uint64_t Size = 0u;
uint64_t Offset = 0u;
BinaryBundleInfo() {}
BinaryBundleInfo(uint64_t Size, uint64_t Offset)
: Size(Size), Offset(Offset) {}
};
StringMap<BinaryBundleInfo> BundlesInfo;
StringMap<BinaryBundleInfo>::iterator CurBundleInfo;
StringMap<BinaryBundleInfo>::iterator NextBundleInfo;
std::string CurWriteBundleTarget;
const OffloadBundlerConfig &BundlerConfig;
public:
BinaryFileHandler(const OffloadBundlerConfig &BC) : BundlerConfig(BC) {}
~BinaryFileHandler() final {}
Error ReadHeader(MemoryBuffer &Input) final {
StringRef FC = Input.getBuffer();
CurBundleInfo = BundlesInfo.end();
size_t ReadChars = sizeof(OFFLOAD_BUNDLER_MAGIC_STR) - 1;
if (ReadChars > FC.size())
return Error::success();
if (llvm::identify_magic(FC) != llvm::file_magic::offload_bundle)
return Error::success();
if (ReadChars + 8 > FC.size())
return Error::success();
uint64_t NumberOfBundles = Read8byteIntegerFromBuffer(FC, ReadChars);
ReadChars += 8;
for (uint64_t i = 0; i < NumberOfBundles; ++i) {
if (ReadChars + 8 > FC.size())
return Error::success();
uint64_t Offset = Read8byteIntegerFromBuffer(FC, ReadChars);
ReadChars += 8;
if (ReadChars + 8 > FC.size())
return Error::success();
uint64_t Size = Read8byteIntegerFromBuffer(FC, ReadChars);
ReadChars += 8;
if (ReadChars + 8 > FC.size())
return Error::success();
uint64_t TripleSize = Read8byteIntegerFromBuffer(FC, ReadChars);
ReadChars += 8;
if (ReadChars + TripleSize > FC.size())
return Error::success();
StringRef Triple(&FC.data()[ReadChars], TripleSize);
ReadChars += TripleSize;
if (!Offset || Offset + Size > FC.size())
return Error::success();
assert(!BundlesInfo.contains(Triple) && "Triple is duplicated??");
BundlesInfo[Triple] = BinaryBundleInfo(Size, Offset);
}
CurBundleInfo = BundlesInfo.end();
NextBundleInfo = BundlesInfo.begin();
return Error::success();
}
Expected<std::optional<StringRef>>
ReadBundleStart(MemoryBuffer &Input) final {
if (NextBundleInfo == BundlesInfo.end())
return std::nullopt;
CurBundleInfo = NextBundleInfo++;
return CurBundleInfo->first();
}
Error ReadBundleEnd(MemoryBuffer &Input) final {
assert(CurBundleInfo != BundlesInfo.end() && "Invalid reader info!");
return Error::success();
}
Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) final {
assert(CurBundleInfo != BundlesInfo.end() && "Invalid reader info!");
StringRef FC = Input.getBuffer();
OS.write(FC.data() + CurBundleInfo->second.Offset,
CurBundleInfo->second.Size);
return Error::success();
}
Error WriteHeader(raw_ostream &OS,
ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs) final {
uint64_t HeaderSize = 0;
HeaderSize += sizeof(OFFLOAD_BUNDLER_MAGIC_STR) - 1;
HeaderSize += 8;
for (auto &T : BundlerConfig.TargetNames) {
HeaderSize += 3 * 8;
HeaderSize += T.size();
}
OS << OFFLOAD_BUNDLER_MAGIC_STR;
Write8byteIntegerToBuffer(OS, BundlerConfig.TargetNames.size());
unsigned Idx = 0;
for (auto &T : BundlerConfig.TargetNames) {
MemoryBuffer &MB = *Inputs[Idx++];
HeaderSize = alignTo(HeaderSize, BundlerConfig.BundleAlignment);
Write8byteIntegerToBuffer(OS, HeaderSize);
Write8byteIntegerToBuffer(OS, MB.getBufferSize());
BundlesInfo[T] = BinaryBundleInfo(MB.getBufferSize(), HeaderSize);
HeaderSize += MB.getBufferSize();
Write8byteIntegerToBuffer(OS, T.size());
OS << T;
}
return Error::success();
}
Error WriteBundleStart(raw_ostream &OS, StringRef TargetTriple) final {
CurWriteBundleTarget = TargetTriple.str();
return Error::success();
}
Error WriteBundleEnd(raw_ostream &OS, StringRef TargetTriple) final {
return Error::success();
}
Error WriteBundle(raw_ostream &OS, MemoryBuffer &Input) final {
auto BI = BundlesInfo[CurWriteBundleTarget];
size_t CurrentPos = OS.tell();
size_t PaddingSize = BI.Offset > CurrentPos ? BI.Offset - CurrentPos : 0;
for (size_t I = 0; I < PaddingSize; ++I)
OS.write('\0');
assert(OS.tell() == BI.Offset);
OS.write(Input.getBufferStart(), Input.getBufferSize());
return Error::success();
}
};
class TempFileHandlerRAII {
public:
~TempFileHandlerRAII() {
for (const auto &File : Files)
sys::fs::remove(File);
}
Expected<StringRef> Create(std::optional<ArrayRef<char>> Contents) {
SmallString<128u> File;
if (std::error_code EC =
sys::fs::createTemporaryFile("clang-offload-bundler", "tmp", File))
return createFileError(File, EC);
Files.push_front(File);
if (Contents) {
std::error_code EC;
raw_fd_ostream OS(File, EC);
if (EC)
return createFileError(File, EC);
OS.write(Contents->data(), Contents->size());
}
return Files.front().str();
}
private:
std::forward_list<SmallString<128u>> Files;
};
class ObjectFileHandler final : public FileHandler {
std::unique_ptr<ObjectFile> Obj;
StringRef getInputFileContents() const { return Obj->getData(); }
static Expected<std::optional<StringRef>>
IsOffloadSection(SectionRef CurSection) {
Expected<StringRef> NameOrErr = CurSection.getName();
if (!NameOrErr)
return NameOrErr.takeError();
if (llvm::identify_magic(*NameOrErr) != llvm::file_magic::offload_bundle)
return std::nullopt;
return NameOrErr->substr(sizeof(OFFLOAD_BUNDLER_MAGIC_STR) - 1);
}
unsigned NumberOfInputs = 0;
unsigned NumberOfProcessedInputs = 0;
section_iterator CurrentSection;
section_iterator NextSection;
const OffloadBundlerConfig &BundlerConfig;
public:
ObjectFileHandler(std::unique_ptr<ObjectFile> ObjIn,
const OffloadBundlerConfig &BC)
: Obj(std::move(ObjIn)), CurrentSection(Obj->section_begin()),
NextSection(Obj->section_begin()), BundlerConfig(BC) {}
~ObjectFileHandler() final {}
Error ReadHeader(MemoryBuffer &Input) final { return Error::success(); }
Expected<std::optional<StringRef>>
ReadBundleStart(MemoryBuffer &Input) final {
while (NextSection != Obj->section_end()) {
CurrentSection = NextSection;
++NextSection;
Expected<std::optional<StringRef>> TripleOrErr =
IsOffloadSection(*CurrentSection);
if (!TripleOrErr)
return TripleOrErr.takeError();
if (*TripleOrErr)
return **TripleOrErr;
}
return std::nullopt;
}
Error ReadBundleEnd(MemoryBuffer &Input) final { return Error::success(); }
Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) final {
Expected<StringRef> ContentOrErr = CurrentSection->getContents();
if (!ContentOrErr)
return ContentOrErr.takeError();
StringRef Content = *ContentOrErr;
std::string ModifiedContent;
if (Content.size() == 1u && Content.front() == 0) {
auto HostBundleOrErr = getHostBundle(
StringRef(Input.getBufferStart(), Input.getBufferSize()));
if (!HostBundleOrErr)
return HostBundleOrErr.takeError();
ModifiedContent = std::move(*HostBundleOrErr);
Content = ModifiedContent;
}
OS.write(Content.data(), Content.size());
return Error::success();
}
Error WriteHeader(raw_ostream &OS,
ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs) final {
assert(BundlerConfig.HostInputIndex != ~0u &&
"Host input index not defined.");
NumberOfInputs = Inputs.size();
return Error::success();
}
Error WriteBundleStart(raw_ostream &OS, StringRef TargetTriple) final {
++NumberOfProcessedInputs;
return Error::success();
}
Error WriteBundleEnd(raw_ostream &OS, StringRef TargetTriple) final {
return Error::success();
}
Error finalizeOutputFile() final {
assert(NumberOfProcessedInputs <= NumberOfInputs &&
"Processing more inputs that actually exist!");
assert(BundlerConfig.HostInputIndex != ~0u &&
"Host input index not defined.");
if (NumberOfProcessedInputs != NumberOfInputs)
return Error::success();
assert(BundlerConfig.ObjcopyPath != "" &&
"llvm-objcopy path not specified");
TempFileHandlerRAII TempFiles;
BumpPtrAllocator Alloc;
StringSaver SS{Alloc};
SmallVector<StringRef, 8u> ObjcopyArgs{"llvm-objcopy"};
for (unsigned I = 0; I < NumberOfInputs; ++I) {
StringRef InputFile = BundlerConfig.InputFileNames[I];
if (I == BundlerConfig.HostInputIndex) {
Expected<StringRef> TempFileOrErr = TempFiles.Create(ArrayRef<char>(0));
if (!TempFileOrErr)
return TempFileOrErr.takeError();
InputFile = *TempFileOrErr;
}
ObjcopyArgs.push_back(
SS.save(Twine("--add-section=") + OFFLOAD_BUNDLER_MAGIC_STR +
BundlerConfig.TargetNames[I] + "=" + InputFile));
ObjcopyArgs.push_back(
SS.save(Twine("--set-section-flags=") + OFFLOAD_BUNDLER_MAGIC_STR +
BundlerConfig.TargetNames[I] + "=readonly,exclude"));
}
ObjcopyArgs.push_back("--");
ObjcopyArgs.push_back(
BundlerConfig.InputFileNames[BundlerConfig.HostInputIndex]);
ObjcopyArgs.push_back(BundlerConfig.OutputFileNames.front());
if (Error Err = executeObjcopy(BundlerConfig.ObjcopyPath, ObjcopyArgs))
return Err;
return Error::success();
}
Error WriteBundle(raw_ostream &OS, MemoryBuffer &Input) final {
return Error::success();
}
private:
Error executeObjcopy(StringRef Objcopy, ArrayRef<StringRef> Args) {
if (BundlerConfig.PrintExternalCommands) {
errs() << "\"" << Objcopy << "\"";
for (StringRef Arg : drop_begin(Args, 1))
errs() << " \"" << Arg << "\"";
errs() << "\n";
} else {
if (sys::ExecuteAndWait(Objcopy, Args))
return createStringError(inconvertibleErrorCode(),
"'llvm-objcopy' tool failed");
}
return Error::success();
}
Expected<std::string> getHostBundle(StringRef Input) {
TempFileHandlerRAII TempFiles;
auto ModifiedObjPathOrErr = TempFiles.Create(std::nullopt);
if (!ModifiedObjPathOrErr)
return ModifiedObjPathOrErr.takeError();
StringRef ModifiedObjPath = *ModifiedObjPathOrErr;
BumpPtrAllocator Alloc;
StringSaver SS{Alloc};
SmallVector<StringRef, 16> ObjcopyArgs{"llvm-objcopy"};
ObjcopyArgs.push_back("--regex");
ObjcopyArgs.push_back("--remove-section=__CLANG_OFFLOAD_BUNDLE__.*");
ObjcopyArgs.push_back("--");
StringRef ObjcopyInputFileName;
if (StringRef(BundlerConfig.FilesType).starts_with("a")) {
auto InputFileOrErr =
TempFiles.Create(ArrayRef<char>(Input.data(), Input.size()));
if (!InputFileOrErr)
return InputFileOrErr.takeError();
ObjcopyInputFileName = *InputFileOrErr;
} else
ObjcopyInputFileName = BundlerConfig.InputFileNames.front();
ObjcopyArgs.push_back(ObjcopyInputFileName);
ObjcopyArgs.push_back(ModifiedObjPath);
if (Error Err = executeObjcopy(BundlerConfig.ObjcopyPath, ObjcopyArgs))
return std::move(Err);
auto BufOrErr = MemoryBuffer::getFile(ModifiedObjPath);
if (!BufOrErr)
return createStringError(BufOrErr.getError(),
"Failed to read back the modified object file");
return BufOrErr->get()->getBuffer().str();
}
};
/// Handler for textual bundles. Each bundle is delimited by marker lines:
///
///   <comment> __CLANG_OFFLOAD_BUNDLE____START__ <bundle-id>
///   ...bundle text...
///   <comment> __CLANG_OFFLOAD_BUNDLE____END__ <bundle-id>
///
/// where <comment> is the line-comment token supplied to the constructor
/// (e.g. "//", "#" or ";").
class TextFileHandler final : public FileHandler {
  /// Line-comment token used for the marker lines.
  StringRef Comment;
  /// "\n<comment> <magic>__START__ "; every marker begins on a new line.
  std::string BundleStartString;
  /// "\n<comment> <magic>__END__ ".
  std::string BundleEndString;
  /// Read cursor into the input buffer; FC.npos once the input is exhausted.
  size_t ReadChars = 0u;

protected:
  Error ReadHeader(MemoryBuffer &Input) final { return Error::success(); }

  Expected<std::optional<StringRef>>
  ReadBundleStart(MemoryBuffer &Input) final {
    StringRef FC = Input.getBuffer();

    ReadChars = FC.find(BundleStartString, ReadChars);
    if (ReadChars == FC.npos)
      return std::nullopt;

    // The bundle ID runs from the end of the marker to the end of the line.
    size_t TripleStart = ReadChars = ReadChars + BundleStartString.size();
    size_t TripleEnd = ReadChars = FC.find('\n', ReadChars);
    if (TripleEnd == FC.npos)
      return std::nullopt;

    // Position the cursor at the first character of the bundle body.
    ++ReadChars;
    return StringRef(&FC.data()[TripleStart], TripleEnd - TripleStart);
  }

  Error ReadBundleEnd(MemoryBuffer &Input) final {
    StringRef FC = Input.getBuffer();

    // Without an end marker the cursor is npos. Report it instead of letting
    // "ReadChars + 1" wrap around to the start of the buffer, which would make
    // bundle iteration loop forever.
    if (ReadChars >= FC.size())
      return createStringError(errc::invalid_argument,
                               "offload bundle end marker not found");

    assert(FC[ReadChars] == '\n' && "The bundle should end with a new line.");
    // Skip the end-marker line itself.
    size_t TripleEnd = ReadChars = FC.find('\n', ReadChars + 1);
    if (TripleEnd != FC.npos)
      ++ReadChars;
    return Error::success();
  }

  Error ReadBundle(raw_ostream &OS, MemoryBuffer &Input) final {
    StringRef FC = Input.getBuffer();
    size_t BundleStart = ReadChars;
    size_t BundleEnd = ReadChars = FC.find(BundleEndString, ReadChars);

    // Without an end marker the length (npos - start) would read far past
    // the end of the buffer.
    if (BundleEnd == FC.npos)
      return createStringError(errc::invalid_argument,
                               "offload bundle end marker not found");

    OS << FC.slice(BundleStart, BundleEnd);
    return Error::success();
  }

  Error WriteHeader(raw_ostream &OS,
                    ArrayRef<std::unique_ptr<MemoryBuffer>> Inputs) final {
    return Error::success();
  }

  Error WriteBundleStart(raw_ostream &OS, StringRef TargetTriple) final {
    OS << BundleStartString << TargetTriple << "\n";
    return Error::success();
  }

  Error WriteBundleEnd(raw_ostream &OS, StringRef TargetTriple) final {
    OS << BundleEndString << TargetTriple << "\n";
    return Error::success();
  }

  Error WriteBundle(raw_ostream &OS, MemoryBuffer &Input) final {
    OS << Input.getBuffer();
    return Error::success();
  }

public:
  TextFileHandler(StringRef Comment) : Comment(Comment) {
    BundleStartString =
        "\n" + Comment.str() + " " OFFLOAD_BUNDLER_MAGIC_STR "__START__ ";
    BundleEndString =
        "\n" + Comment.str() + " " OFFLOAD_BUNDLER_MAGIC_STR "__END__ ";
  }

  /// While listing IDs, skip over the bundle body and its end marker so the
  /// next ReadBundleStart() searches after this bundle.
  Error listBundleIDsCallback(MemoryBuffer &Input,
                              const BundleInfo &Info) final {
    ReadChars = Input.getBuffer().find(BundleEndString, ReadChars);
    return ReadBundleEnd(Input);
  }
};
}
/// Returns an ObjectFileHandler when \p FirstInput parses as an object file,
/// and a BinaryFileHandler otherwise.
static std::unique_ptr<FileHandler>
CreateObjectFileHandler(MemoryBuffer &FirstInput,
                        const OffloadBundlerConfig &BundlerConfig) {
  Expected<std::unique_ptr<Binary>> BinaryOrErr = createBinary(FirstInput);

  // A parse failure is not an error here: the input is simply not an object.
  const bool ParseFailed = errorToBool(BinaryOrErr.takeError());
  if (ParseFailed || !isa<ObjectFile>(*BinaryOrErr))
    return std::make_unique<BinaryFileHandler>(BundlerConfig);

  auto *Object = cast<ObjectFile>(BinaryOrErr->release());
  return std::make_unique<ObjectFileHandler>(
      std::unique_ptr<ObjectFile>(Object), BundlerConfig);
}
/// Creates the handler for BundlerConfig.FilesType.
///
/// - "i", "ii", "cui", "hipi": text bundles commented with "//".
/// - "d", "s": text bundles commented with "#".
/// - "ll": text bundles commented with ";".
/// - "o", "a": CreateObjectFileHandler().
/// - "bc", "gch", "ast": binary bundles.
///
/// Any other type is an invalid_argument error.
static Expected<std::unique_ptr<FileHandler>>
CreateFileHandler(MemoryBuffer &FirstInput,
                  const OffloadBundlerConfig &BundlerConfig) {
  std::string FilesType = BundlerConfig.FilesType;

  const bool UsesSlashComments = FilesType == "i" || FilesType == "ii" ||
                                 FilesType == "cui" || FilesType == "hipi";
  if (UsesSlashComments)
    return std::make_unique<TextFileHandler>("//");
  if (FilesType == "d" || FilesType == "s")
    return std::make_unique<TextFileHandler>("#");
  if (FilesType == "ll")
    return std::make_unique<TextFileHandler>(";");
  if (FilesType == "o" || FilesType == "a")
    return CreateObjectFileHandler(FirstInput, BundlerConfig);
  if (FilesType == "bc" || FilesType == "gch" || FilesType == "ast")
    return std::make_unique<BinaryFileHandler>(BundlerConfig);

  return createStringError(errc::invalid_argument,
                           "'" + FilesType + "': invalid file type specified");
}
/// Sets the default compression settings: zstd at level 3 when available,
/// otherwise zlib at its default level.
///
/// Unless OFFLOAD_BUNDLER_IGNORE_ENV_VAR is "1", it then applies overrides
/// from the environment:
/// - OFFLOAD_BUNDLER_VERBOSE ("1" enables)
/// - OFFLOAD_BUNDLER_COMPRESS ("1" enables)
/// - OFFLOAD_BUNDLER_COMPRESSION_LEVEL (decimal integer; invalid values are
///   reported and ignored)
OffloadBundlerConfig::OffloadBundlerConfig() {
  namespace comp = llvm::compression;
  if (comp::zstd::isAvailable()) {
    CompressionFormat = comp::Format::Zstd;
    CompressionLevel = 3;
  } else if (comp::zlib::isAvailable()) {
    CompressionFormat = comp::Format::Zlib;
    CompressionLevel = comp::zlib::DefaultCompression;
  }

  auto ReadEnv = [](const char *Name) {
    return llvm::sys::Process::GetEnv(Name);
  };

  if (auto IgnoreEnv = ReadEnv("OFFLOAD_BUNDLER_IGNORE_ENV_VAR");
      IgnoreEnv && *IgnoreEnv == "1")
    return;

  if (auto VerboseEnv = ReadEnv("OFFLOAD_BUNDLER_VERBOSE"))
    Verbose = *VerboseEnv == "1";

  if (auto CompressEnv = ReadEnv("OFFLOAD_BUNDLER_COMPRESS"))
    Compress = *CompressEnv == "1";

  if (auto LevelEnv = ReadEnv("OFFLOAD_BUNDLER_COMPRESSION_LEVEL")) {
    llvm::StringRef LevelText = *LevelEnv;
    int ParsedLevel;
    // getAsInteger() returns true on failure.
    if (LevelText.getAsInteger(10, ParsedLevel)) {
      llvm::errs()
          << "Warning: Invalid value for OFFLOAD_BUNDLER_COMPRESSION_LEVEL: "
          << LevelText.str() << ". Ignoring it.\n";
    } else {
      CompressionLevel = ParsedLevel;
    }
  }
}
/// Formats \p Value in decimal with ',' between groups of three digits,
/// counted from the right, e.g. 1234567 -> "1,234,567".
static std::string formatWithCommas(unsigned long long Value) {
  const std::string Digits = std::to_string(Value);
  std::string Out;
  Out.reserve(Digits.size() + Digits.size() / 3);
  for (std::size_t I = 0; I < Digits.size(); ++I) {
    // A separator precedes every digit that starts a complete group of three
    // reaching the last digit (never before the first digit).
    if (I != 0 && (Digits.size() - I) % 3 == 0)
      Out += ',';
    Out += Digits[I];
  }
  return Out;
}
/// Compresses \p Input with the format and level in \p P and prepends the
/// compressed-bundle header:
///
///   MagicNumber | uint16 Version | uint16 Method | uint32 TotalFileSize |
///   uint32 UncompressedSize | uint64 TruncatedHash | compressed bytes
///
/// Header integers are written in host byte order. TruncatedHash is the low
/// 64 bits of the MD5 of the uncompressed input.
///
/// Returns an error if neither zstd nor zlib is available, or if a size does
/// not fit its 32-bit header field. With \p Verbose, timing and size
/// statistics are printed to stderr.
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
CompressedOffloadBundle::compress(llvm::compression::Params P,
                                  const llvm::MemoryBuffer &Input,
                                  bool Verbose) {
  if (!llvm::compression::zstd::isAvailable() &&
      !llvm::compression::zlib::isAvailable())
    return createStringError(llvm::inconvertibleErrorCode(),
                             "Compression not supported");

  // The header stores the uncompressed size in a 32-bit field. Larger inputs
  // would be silently truncated into a corrupt bundle.
  const uint64_t InputSize = Input.getBuffer().size();
  if (InputSize > UINT32_MAX)
    return createStringError(llvm::inconvertibleErrorCode(),
                             "Input too large for compressed bundle format: "
                             "uncompressed size exceeds 4 GiB");

  llvm::Timer HashTimer("Hash Calculation Timer", "Hash calculation time",
                        ClangOffloadBundlerTimerGroup);
  if (Verbose)
    HashTimer.startTimer();
  llvm::MD5 Hash;
  llvm::MD5::MD5Result Result;
  Hash.update(Input.getBuffer());
  Hash.final(Result);
  uint64_t TruncatedHash = Result.low();
  if (Verbose)
    HashTimer.stopTimer();

  SmallVector<uint8_t, 0> CompressedBuffer;
  auto BufferUint8 = llvm::ArrayRef<uint8_t>(
      reinterpret_cast<const uint8_t *>(Input.getBuffer().data()),
      Input.getBuffer().size());
  llvm::Timer CompressTimer("Compression Timer", "Compression time",
                            ClangOffloadBundlerTimerGroup);
  if (Verbose)
    CompressTimer.startTimer();
  llvm::compression::compress(P, BufferUint8, CompressedBuffer);
  if (Verbose)
    CompressTimer.stopTimer();

  uint16_t CompressionMethod = static_cast<uint16_t>(P.format);
  uint32_t UncompressedSize = static_cast<uint32_t>(InputSize);
  // Compute the total in 64 bits first; it must also fit its 32-bit field.
  uint64_t TotalFileSize64 = MagicNumber.size() + sizeof(uint32_t) +
                             sizeof(Version) + sizeof(CompressionMethod) +
                             sizeof(UncompressedSize) + sizeof(TruncatedHash) +
                             CompressedBuffer.size();
  if (TotalFileSize64 > UINT32_MAX)
    return createStringError(llvm::inconvertibleErrorCode(),
                             "Compressed bundle too large: total file size "
                             "exceeds 4 GiB");
  uint32_t TotalFileSize = static_cast<uint32_t>(TotalFileSize64);

  SmallVector<char, 0> FinalBuffer;
  llvm::raw_svector_ostream OS(FinalBuffer);
  OS << MagicNumber;
  OS.write(reinterpret_cast<const char *>(&Version), sizeof(Version));
  OS.write(reinterpret_cast<const char *>(&CompressionMethod),
           sizeof(CompressionMethod));
  OS.write(reinterpret_cast<const char *>(&TotalFileSize),
           sizeof(TotalFileSize));
  OS.write(reinterpret_cast<const char *>(&UncompressedSize),
           sizeof(UncompressedSize));
  OS.write(reinterpret_cast<const char *>(&TruncatedHash),
           sizeof(TruncatedHash));
  OS.write(reinterpret_cast<const char *>(CompressedBuffer.data()),
           CompressedBuffer.size());

  if (Verbose) {
    auto MethodUsed =
        P.format == llvm::compression::Format::Zstd ? "zstd" : "zlib";
    double CompressionRate =
        static_cast<double>(UncompressedSize) / CompressedBuffer.size();
    double CompressionTimeSeconds = CompressTimer.getTotalTime().getWallTime();
    double CompressionSpeedMBs =
        (UncompressedSize / (1024.0 * 1024.0)) / CompressionTimeSeconds;
    llvm::errs() << "Compressed bundle format version: " << Version << "\n"
                 << "Total file size (including headers): "
                 << formatWithCommas(TotalFileSize) << " bytes\n"
                 << "Compression method used: " << MethodUsed << "\n"
                 << "Compression level: " << P.level << "\n"
                 << "Binary size before compression: "
                 << formatWithCommas(UncompressedSize) << " bytes\n"
                 << "Binary size after compression: "
                 << formatWithCommas(CompressedBuffer.size()) << " bytes\n"
                 << "Compression rate: "
                 << llvm::format("%.2lf", CompressionRate) << "\n"
                 << "Compression ratio: "
                 << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"
                 << "Compression speed: "
                 << llvm::format("%.2lf MB/s", CompressionSpeedMBs) << "\n"
                 << "Truncated MD5 hash: "
                 << llvm::format_hex(TruncatedHash, 16) << "\n";
  }
  return llvm::MemoryBuffer::getMemBufferCopy(
      llvm::StringRef(FinalBuffer.data(), FinalBuffer.size()));
}
/// Decompresses a compressed offload bundle produced by compress().
///
/// Inputs shorter than a v1 header, or without the compressed-bundle magic,
/// are returned unchanged (copied). Version 1 headers have no TotalFileSize
/// field; version 2 headers do. Newer versions than this code writes are
/// rejected, as are unknown compression methods and decompression failures.
///
/// With \p Verbose, statistics and a hash check (stored vs. recomputed,
/// reported only) are printed to stderr.
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
CompressedOffloadBundle::decompress(const llvm::MemoryBuffer &Input,
                                    bool Verbose) {
  StringRef Blob = Input.getBuffer();

  if (Blob.size() < V1HeaderSize)
    return llvm::MemoryBuffer::getMemBufferCopy(Blob);

  if (llvm::identify_magic(Blob) !=
      llvm::file_magic::offload_bundle_compressed) {
    if (Verbose)
      llvm::errs() << "Uncompressed bundle.\n";
    return llvm::MemoryBuffer::getMemBufferCopy(Blob);
  }

  size_t CurrentOffset = MagicSize;

  uint16_t ThisVersion;
  memcpy(&ThisVersion, Blob.data() + CurrentOffset, sizeof(uint16_t));
  CurrentOffset += VersionFieldSize;

  // Only header layouts up to the version compress() writes are understood;
  // refuse newer ones rather than misparse them.
  if (ThisVersion > Version)
    return createStringError(inconvertibleErrorCode(),
                             "Unsupported compressed bundle version: " +
                                 std::to_string(ThisVersion));

  uint16_t CompressionMethod;
  memcpy(&CompressionMethod, Blob.data() + CurrentOffset, sizeof(uint16_t));
  CurrentOffset += MethodFieldSize;

  // Only present (and only reported) for version 2 and later.
  uint32_t TotalFileSize = 0;
  if (ThisVersion >= 2) {
    if (Blob.size() < V2HeaderSize)
      return createStringError(inconvertibleErrorCode(),
                               "Compressed bundle header size too small");
    memcpy(&TotalFileSize, Blob.data() + CurrentOffset, sizeof(uint32_t));
    CurrentOffset += FileSizeFieldSize;
  }

  uint32_t UncompressedSize;
  memcpy(&UncompressedSize, Blob.data() + CurrentOffset, sizeof(uint32_t));
  CurrentOffset += UncompressedSizeFieldSize;

  uint64_t StoredHash;
  memcpy(&StoredHash, Blob.data() + CurrentOffset, sizeof(uint64_t));
  CurrentOffset += HashFieldSize;

  llvm::compression::Format CompressionFormat;
  if (CompressionMethod ==
      static_cast<uint16_t>(llvm::compression::Format::Zlib))
    CompressionFormat = llvm::compression::Format::Zlib;
  else if (CompressionMethod ==
           static_cast<uint16_t>(llvm::compression::Format::Zstd))
    CompressionFormat = llvm::compression::Format::Zstd;
  else
    return createStringError(inconvertibleErrorCode(),
                             "Unknown compressing method");

  llvm::Timer DecompressTimer("Decompression Timer", "Decompression time",
                              ClangOffloadBundlerTimerGroup);
  if (Verbose)
    DecompressTimer.startTimer();

  SmallVector<uint8_t, 0> DecompressedData;
  StringRef CompressedData = Blob.substr(CurrentOffset);
  if (llvm::Error DecompressionError = llvm::compression::decompress(
          CompressionFormat, llvm::arrayRefFromStringRef(CompressedData),
          DecompressedData, UncompressedSize))
    return createStringError(inconvertibleErrorCode(),
                             "Could not decompress embedded file contents: " +
                                 llvm::toString(std::move(DecompressionError)));

  if (Verbose) {
    DecompressTimer.stopTimer();

    double DecompressionTimeSeconds =
        DecompressTimer.getTotalTime().getWallTime();

    // Recompute the hash only for reporting; a mismatch is not an error.
    llvm::Timer HashRecalcTimer("Hash Recalculation Timer",
                                "Hash recalculation time",
                                ClangOffloadBundlerTimerGroup);
    HashRecalcTimer.startTimer();
    llvm::MD5 Hash;
    llvm::MD5::MD5Result Result;
    Hash.update(llvm::ArrayRef<uint8_t>(DecompressedData.data(),
                                        DecompressedData.size()));
    Hash.final(Result);
    uint64_t RecalculatedHash = Result.low();
    HashRecalcTimer.stopTimer();
    bool HashMatch = (StoredHash == RecalculatedHash);

    double CompressionRate =
        static_cast<double>(UncompressedSize) / CompressedData.size();
    double DecompressionSpeedMBs =
        (UncompressedSize / (1024.0 * 1024.0)) / DecompressionTimeSeconds;

    llvm::errs() << "Compressed bundle format version: " << ThisVersion << "\n";
    if (ThisVersion >= 2)
      llvm::errs() << "Total file size (from header): "
                   << formatWithCommas(TotalFileSize) << " bytes\n";
    llvm::errs() << "Decompression method: "
                 << (CompressionFormat == llvm::compression::Format::Zlib
                         ? "zlib"
                         : "zstd")
                 << "\n"
                 << "Size before decompression: "
                 << formatWithCommas(CompressedData.size()) << " bytes\n"
                 << "Size after decompression: "
                 << formatWithCommas(UncompressedSize) << " bytes\n"
                 << "Compression rate: "
                 << llvm::format("%.2lf", CompressionRate) << "\n"
                 << "Compression ratio: "
                 << llvm::format("%.2lf%%", 100.0 / CompressionRate) << "\n"
                 << "Decompression speed: "
                 << llvm::format("%.2lf MB/s", DecompressionSpeedMBs) << "\n"
                 << "Stored hash: " << llvm::format_hex(StoredHash, 16) << "\n"
                 << "Recalculated hash: "
                 << llvm::format_hex(RecalculatedHash, 16) << "\n"
                 << "Hashes match: " << (HashMatch ? "Yes" : "No") << "\n";
  }

  return llvm::MemoryBuffer::getMemBufferCopy(
      llvm::toStringRef(DecompressedData));
}
/// Prints the ID of every bundle in \p InputFileName to stdout.
///
/// The file is loaded with MemoryBuffer::getFileOrSTDIN and decompressed
/// first if it is a compressed bundle. Its handler is chosen by
/// BundlerConfig.FilesType.
Error OffloadBundler::ListBundleIDsInFile(
    StringRef InputFileName, const OffloadBundlerConfig &BundlerConfig) {
  ErrorOr<std::unique_ptr<MemoryBuffer>> CodeOrErr =
      MemoryBuffer::getFileOrSTDIN(InputFileName);
  if (!CodeOrErr)
    return createFileError(InputFileName, CodeOrErr.getError());

  auto DecompressedOrErr =
      CompressedOffloadBundle::decompress(**CodeOrErr, BundlerConfig.Verbose);
  if (!DecompressedOrErr) {
    std::string Reason = llvm::toString(DecompressedOrErr.takeError());
    return createStringError(inconvertibleErrorCode(),
                             "Failed to decompress input: " + Reason);
  }
  MemoryBuffer &Decompressed = **DecompressedOrErr;

  auto HandlerOrErr = CreateFileHandler(Decompressed, BundlerConfig);
  if (!HandlerOrErr)
    return HandlerOrErr.takeError();
  assert(*HandlerOrErr);
  return (*HandlerOrErr)->listBundleIDs(Decompressed);
}
bool isCodeObjectCompatible(const OffloadTargetInfo &CodeObjectInfo,
const OffloadTargetInfo &TargetInfo) {
if (CodeObjectInfo == TargetInfo) {
DEBUG_WITH_TYPE("CodeObjectCompatibility",
dbgs() << "Compatible: Exact match: \t[CodeObject: "
<< CodeObjectInfo.str()
<< "]\t:\t[Target: " << TargetInfo.str() << "]\n");
return true;
}
if (!CodeObjectInfo.isOffloadKindCompatible(TargetInfo.OffloadKind) ||
!CodeObjectInfo.Triple.isCompatibleWith(TargetInfo.Triple)) {
DEBUG_WITH_TYPE(
"CodeObjectCompatibility",
dbgs() << "Incompatible: Kind/Triple mismatch \t[CodeObject: "
<< CodeObjectInfo.str() << "]\t:\t[Target: " << TargetInfo.str()
<< "]\n");
return false;
}
llvm::StringMap<bool> CodeObjectFeatureMap, TargetFeatureMap;
std::optional<StringRef> CodeObjectProc = clang::parseTargetID(
CodeObjectInfo.Triple, CodeObjectInfo.TargetID, &CodeObjectFeatureMap);
std::optional<StringRef> TargetProc = clang::parseTargetID(
TargetInfo.Triple, TargetInfo.TargetID, &TargetFeatureMap);
if (!TargetProc || !CodeObjectProc ||
CodeObjectProc.value() != TargetProc.value()) {
DEBUG_WITH_TYPE("CodeObjectCompatibility",
dbgs() << "Incompatible: Processor mismatch \t[CodeObject: "
<< CodeObjectInfo.str()
<< "]\t:\t[Target: " << TargetInfo.str() << "]\n");
return false;
}
if (CodeObjectFeatureMap.getNumItems() > TargetFeatureMap.getNumItems()) {
DEBUG_WITH_TYPE("CodeObjectCompatibility",
dbgs() << "Incompatible: CodeObject has more features "
"than target \t[CodeObject: "
<< CodeObjectInfo.str()
<< "]\t:\t[Target: " << TargetInfo.str() << "]\n");
return false;
}
for (const auto &CodeObjectFeature : CodeObjectFeatureMap) {
auto TargetFeature = TargetFeatureMap.find(CodeObjectFeature.getKey());
if (TargetFeature == TargetFeatureMap.end()) {
DEBUG_WITH_TYPE(
"CodeObjectCompatibility",
dbgs()
<< "Incompatible: Value of CodeObject's non-ANY feature is "
"not matching with Target feature's ANY value \t[CodeObject: "
<< CodeObjectInfo.str() << "]\t:\t[Target: " << TargetInfo.str()
<< "]\n");
return false;
} else if (TargetFeature->getValue() != CodeObjectFeature.getValue()) {
DEBUG_WITH_TYPE(
"CodeObjectCompatibility",
dbgs() << "Incompatible: Value of CodeObject's non-ANY feature is "
"not matching with Target feature's non-ANY value "
"\t[CodeObject: "
<< CodeObjectInfo.str()
<< "]\t:\t[Target: " << TargetInfo.str() << "]\n");
return false;
}
}
DEBUG_WITH_TYPE(
"CodeObjectCompatibility",
dbgs() << "Compatible: Target IDs are compatible \t[CodeObject: "
<< CodeObjectInfo.str() << "]\t:\t[Target: " << TargetInfo.str()
<< "]\n");
return true;
}
/// Bundle the files in BundlerConfig.InputFileNames into the single output
/// file BundlerConfig.OutputFileNames.front().
///
/// Inputs are paired by position with BundlerConfig.TargetNames. The
/// FileHandler is chosen from the host input (BundlerConfig.HostInputIndex),
/// or from the first input when BundlerConfig.AllowNoHost is set. The complete
/// bundle is assembled in memory and optionally compressed with
/// CompressedOffloadBundle::compress. Only then is the output file opened and
/// written, so a compression failure does not leave a truncated output behind.
/// Finally the handler's finalizeOutputFile() hook is run.
///
/// \returns an error if an input or the output cannot be opened, there are
/// fewer inputs than targets, the handler index is out of range, a handler
/// step fails, or compression fails.
Error OffloadBundler::BundleFiles() {
  // Read every input up front.
  SmallVector<std::unique_ptr<MemoryBuffer>, 8u> InputBuffers;
  InputBuffers.reserve(BundlerConfig.InputFileNames.size());
  for (auto &I : BundlerConfig.InputFileNames) {
    ErrorOr<std::unique_ptr<MemoryBuffer>> CodeOrErr =
        MemoryBuffer::getFileOrSTDIN(I);
    if (std::error_code EC = CodeOrErr.getError())
      return createFileError(I, EC);
    InputBuffers.emplace_back(std::move(*CodeOrErr));
  }

  // Each target consumes the input at the same position. Reject a shortfall
  // instead of reading past the end of InputBuffers below.
  if (InputBuffers.size() < BundlerConfig.TargetNames.size())
    return createStringError(inconvertibleErrorCode(),
                             "not enough input files for the number of "
                             "targets");

  assert((BundlerConfig.HostInputIndex != ~0u || BundlerConfig.AllowNoHost) &&
         "Host input index undefined??");
  const unsigned HandlerInputIndex =
      BundlerConfig.AllowNoHost ? 0 : BundlerConfig.HostInputIndex;
  // The assert above is compiled out in release builds; keep the index access
  // in bounds regardless.
  if (HandlerInputIndex >= InputBuffers.size())
    return createStringError(inconvertibleErrorCode(),
                             "invalid input index for the bundle file handler");

  Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr =
      CreateFileHandler(*InputBuffers[HandlerInputIndex], BundlerConfig);
  if (!FileHandlerOrErr)
    return FileHandlerOrErr.takeError();
  std::unique_ptr<FileHandler> &FH = *FileHandlerOrErr;
  assert(FH);

  // Assemble the complete bundle in memory.
  SmallVector<char, 0> Buffer;
  llvm::raw_svector_ostream BufferStream(Buffer);
  if (Error Err = FH->WriteHeader(BufferStream, InputBuffers))
    return Err;
  for (size_t Idx = 0, E = BundlerConfig.TargetNames.size(); Idx != E; ++Idx) {
    const auto &TargetName = BundlerConfig.TargetNames[Idx];
    if (Error Err = FH->WriteBundleStart(BufferStream, TargetName))
      return Err;
    if (Error Err = FH->WriteBundle(BufferStream, *InputBuffers[Idx]))
      return Err;
    if (Error Err = FH->WriteBundleEnd(BufferStream, TargetName))
      return Err;
  }

  // Bytes destined for the output file. They are either the raw bundle in
  // Buffer, or the compressed form owned by CompressedMemBuffer. No extra
  // copy is made in either case.
  StringRef Payload(Buffer.data(), Buffer.size());
  std::unique_ptr<llvm::MemoryBuffer> CompressedMemBuffer;
  if (BundlerConfig.Compress) {
    // Non-owning view of Buffer. Buffer is not NUL-terminated, so the
    // terminator requirement must be disabled.
    std::unique_ptr<llvm::MemoryBuffer> BufferMemory =
        llvm::MemoryBuffer::getMemBuffer(Payload, "",
                                         /*RequiresNullTerminator=*/false);
    auto CompressionResult = CompressedOffloadBundle::compress(
        {BundlerConfig.CompressionFormat, BundlerConfig.CompressionLevel,
         true},
        *BufferMemory, BundlerConfig.Verbose);
    if (!CompressionResult)
      return CompressionResult.takeError();
    CompressedMemBuffer = std::move(*CompressionResult);
    Payload = CompressedMemBuffer->getBuffer();
  }

  std::error_code EC;
  raw_fd_ostream OutputFile(BundlerConfig.OutputFileNames.front(), EC,
                            sys::fs::OF_None);
  if (EC)
    return createFileError(BundlerConfig.OutputFileNames.front(), EC);
  OutputFile.write(Payload.data(), Payload.size());
  return FH->finalizeOutputFile();
}
/// Extract from the bundle BundlerConfig.InputFileNames.front() the code
/// object for each target in BundlerConfig.TargetNames. Each code object is
/// written to the output file at the same position in
/// BundlerConfig.OutputFileNames.
///
/// The input is first passed through CompressedOffloadBundle::decompress.
/// Each bundle read from it is matched with isCodeObjectCompatible against the
/// targets still in the worklist. The first compatible worklist entry found
/// receives the bundle and is removed. Bundles that match no remaining target
/// are skipped.
///
/// Afterwards:
///  - unless AllowMissingBundles is set, any unsatisfied target is an error;
///  - if the worklist still holds as many entries as there are target names,
///    every output is created and host-kind outputs receive a verbatim copy
///    of the (decompressed) input;
///  - otherwise, a missing host bundle is an error unless HostInputIndex is
///    ~0u or AllowMissingBundles is set, and the remaining outputs are
///    created empty.
Error OffloadBundler::UnbundleFiles() {
  ErrorOr<std::unique_ptr<MemoryBuffer>> CodeOrErr =
      MemoryBuffer::getFileOrSTDIN(BundlerConfig.InputFileNames.front());
  if (std::error_code EC = CodeOrErr.getError())
    return createFileError(BundlerConfig.InputFileNames.front(), EC);

  Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
      CompressedOffloadBundle::decompress(**CodeOrErr, BundlerConfig.Verbose);
  if (!DecompressedBufferOrErr)
    return createStringError(
        inconvertibleErrorCode(),
        "Failed to decompress input: " +
            llvm::toString(DecompressedBufferOrErr.takeError()));
  MemoryBuffer &Input = **DecompressedBufferOrErr;

  Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr =
      CreateFileHandler(Input, BundlerConfig);
  if (!FileHandlerOrErr)
    return FileHandlerOrErr.takeError();
  std::unique_ptr<FileHandler> &FH = *FileHandlerOrErr;
  assert(FH);

  if (Error Err = FH->ReadHeader(Input))
    return Err;

  // Targets still waiting for a bundle, mapped to their output file names.
  StringMap<StringRef> Worklist;
  auto OutputName = BundlerConfig.OutputFileNames.begin();
  for (auto &TargetName : BundlerConfig.TargetNames) {
    Worklist[TargetName] = *OutputName;
    ++OutputName;
  }

  bool FoundHostBundle = false;
  while (!Worklist.empty()) {
    Expected<std::optional<StringRef>> CurTripleOrErr =
        FH->ReadBundleStart(Input);
    if (!CurTripleOrErr)
      return CurTripleOrErr.takeError();
    // No more bundles in the input.
    if (!*CurTripleOrErr)
      break;
    StringRef CurTriple = **CurTripleOrErr;
    assert(!CurTriple.empty());

    // Parse this bundle's ID once and reuse it for every comparison below.
    OffloadTargetInfo CodeObjectInfo(CurTriple, BundlerConfig);
    auto Match = std::find_if(
        Worklist.begin(), Worklist.end(), [&](const auto &Entry) {
          return isCodeObjectCompatible(
              CodeObjectInfo, OffloadTargetInfo(Entry.getKey(), BundlerConfig));
        });
    // No remaining target wants this bundle; move on to the next one.
    if (Match == Worklist.end())
      continue;

    StringRef OutputFileName = Match->getValue();
    std::error_code EC;
    raw_fd_ostream OutputFile(OutputFileName, EC, sys::fs::OF_None);
    if (EC)
      return createFileError(OutputFileName, EC);
    if (Error Err = FH->ReadBundle(OutputFile, Input))
      return Err;
    if (Error Err = FH->ReadBundleEnd(Input))
      return Err;
    Worklist.erase(Match);

    if (CodeObjectInfo.hasHostKind())
      FoundHostBundle = true;
  }

  if (!BundlerConfig.AllowMissingBundles && !Worklist.empty()) {
    // Produce "Can't find bundles for a", "... a and b" or "... a, b, and c",
    // listing the missing targets in sorted order.
    std::string ErrMsg = "Can't find bundles for";
    std::set<StringRef> Sorted;
    for (auto &E : Worklist)
      Sorted.insert(E.first());
    unsigned I = 0;
    unsigned Last = Sorted.size() - 1;
    for (auto &E : Sorted) {
      if (I != 0 && Last > 1)
        ErrMsg += ",";
      ErrMsg += " ";
      if (I == Last && I != 0)
        ErrMsg += "and ";
      ErrMsg += E.str();
      ++I;
    }
    return createStringError(inconvertibleErrorCode(), ErrMsg);
  }

  // Nothing was extracted: create every output, copying the input verbatim
  // into host-kind outputs.
  if (Worklist.size() == BundlerConfig.TargetNames.size()) {
    for (auto &E : Worklist) {
      std::error_code EC;
      raw_fd_ostream OutputFile(E.second, EC, sys::fs::OF_None);
      if (EC)
        return createFileError(E.second, EC);
      auto OffloadInfo = OffloadTargetInfo(E.getKey(), BundlerConfig);
      if (OffloadInfo.hasHostKind())
        OutputFile.write(Input.getBufferStart(), Input.getBufferSize());
    }
    return Error::success();
  }

  if (!(FoundHostBundle || BundlerConfig.HostInputIndex == ~0u ||
        BundlerConfig.AllowMissingBundles))
    return createStringError(inconvertibleErrorCode(),
                             "Can't find bundle for the host target");

  // Create (empty) outputs for the targets that received no bundle.
  for (auto &E : Worklist) {
    std::error_code EC;
    raw_fd_ostream OutputFile(E.second, EC, sys::fs::OF_None);
    if (EC)
      return createFileError(E.second, EC);
  }
  return Error::success();
}
/// Archive flavor used for output archives. It is Darwin-style when the
/// default target triple is a Darwin OS, and GNU-style otherwise.
static Archive::Kind getDefaultArchiveKindForHost() {
  const Triple DefaultTriple(sys::getDefaultTargetTriple());
  if (DefaultTriple.isOSDarwin())
    return Archive::K_DARWIN;
  return Archive::K_GNU;
}
/// Append to \p CompatibleTargets every name in BundlerConfig.TargetNames
/// whose target isCodeObjectCompatible with \p CodeObjectInfo.
///
/// \p CompatibleTargets must be empty on entry. If it is not, nothing is
/// collected and false is returned.
/// \returns true iff at least one compatible target was collected.
static bool
getCompatibleOffloadTargets(OffloadTargetInfo &CodeObjectInfo,
                            SmallVectorImpl<StringRef> &CompatibleTargets,
                            const OffloadBundlerConfig &BundlerConfig) {
  // Refuse to extend a list the caller has already populated.
  if (!CompatibleTargets.empty()) {
    DEBUG_WITH_TYPE("CodeObjectCompatibility",
                    dbgs() << "CompatibleTargets list should be empty\n");
    return false;
  }

  for (const auto &Name : BundlerConfig.TargetNames)
    if (isCodeObjectCompatible(CodeObjectInfo,
                               OffloadTargetInfo(Name, BundlerConfig)))
      CompatibleTargets.push_back(Name);

  return CompatibleTargets.size() != 0;
}
/// Verify that no member of the archive \p ArchiveName bundles a pair of
/// target IDs that clang::getConflictTargetIDCombination reports as
/// conflicting.
///
/// Each member is passed through CompressedOffloadBundle::decompress before
/// its bundle IDs are collected, exactly as UnbundleArchive does. Compressed
/// members are therefore checked in the same form the unbundler reads them.
///
/// \returns an error if the archive or a member cannot be read or
/// decompressed, a member's bundle IDs cannot be obtained, or a conflicting
/// target-ID pair is found; success otherwise.
static Error
CheckHeterogeneousArchive(StringRef ArchiveName,
                          const OffloadBundlerConfig &BundlerConfig) {
  ErrorOr<std::unique_ptr<MemoryBuffer>> BufOrErr =
      MemoryBuffer::getFileOrSTDIN(ArchiveName, true, false);
  if (std::error_code EC = BufOrErr.getError())
    return createFileError(ArchiveName, EC);
  // Owns the archive bytes; must outlive the Archive object and its children.
  std::unique_ptr<MemoryBuffer> ArchiveBuffer = std::move(*BufOrErr);

  Expected<std::unique_ptr<llvm::object::Archive>> LibOrErr =
      Archive::create(ArchiveBuffer->getMemBufferRef());
  if (!LibOrErr)
    return LibOrErr.takeError();
  auto Archive = std::move(*LibOrErr);

  Error ArchiveErr = Error::success();
  auto ChildEnd = Archive->child_end();
  for (auto ArchiveIter = Archive->child_begin(ArchiveErr);
       ArchiveIter != ChildEnd; ++ArchiveIter) {
    if (ArchiveErr)
      return ArchiveErr;
    auto ArchiveChildNameOrErr = (*ArchiveIter).getName();
    if (!ArchiveChildNameOrErr)
      return ArchiveChildNameOrErr.takeError();

    auto CodeObjectBufferRefOrErr = (*ArchiveIter).getMemoryBufferRef();
    if (!CodeObjectBufferRefOrErr)
      return CodeObjectBufferRefOrErr.takeError();
    auto TempCodeObjectBuffer =
        MemoryBuffer::getMemBuffer(*CodeObjectBufferRefOrErr, false);

    // Inspect the member in the same (decompressed) form UnbundleArchive
    // hands to its FileHandler.
    Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
        CompressedOffloadBundle::decompress(*TempCodeObjectBuffer,
                                            BundlerConfig.Verbose);
    if (!DecompressedBufferOrErr)
      return createStringError(
          inconvertibleErrorCode(),
          "Failed to decompress code object: " +
              llvm::toString(DecompressedBufferOrErr.takeError()));
    MemoryBuffer &CodeObjectBuffer = **DecompressedBufferOrErr;

    Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr =
        CreateFileHandler(CodeObjectBuffer, BundlerConfig);
    if (!FileHandlerOrErr)
      return FileHandlerOrErr.takeError();
    std::unique_ptr<FileHandler> &FileHandler = *FileHandlerOrErr;
    assert(FileHandler);

    std::set<StringRef> BundleIds;
    if (Error CodeObjectFileError =
            FileHandler->getBundleIDs(CodeObjectBuffer, BundleIds))
      return CodeObjectFileError;

    auto &&ConflictingArchs = clang::getConflictTargetIDCombination(BundleIds);
    if (ConflictingArchs) {
      std::string ErrMsg =
          Twine("conflicting TargetIDs [" + ConflictingArchs.value().first +
                ", " + ConflictingArchs.value().second + "] found in " +
                ArchiveChildNameOrErr.get() + " of " + ArchiveName)
              .str();
      return createStringError(inconvertibleErrorCode(), ErrMsg);
    }
  }
  // An error while advancing to the next member ends the loop early and is
  // left in ArchiveErr; report it.
  return ArchiveErr;
}
/// Unbundle the heterogeneous archive BundlerConfig.InputFileNames.front()
/// into one output archive per target in BundlerConfig.TargetNames. Outputs
/// are paired by position with BundlerConfig.OutputFileNames.
///
/// Each archive member is decompressed and every bundle entry in it is read.
/// An entry's data is added, as a single shared member buffer, to the output
/// archive of every target it is compatible with (per
/// getCompatibleOffloadTargets). The member name is derived from the archive
/// member's stem, the bundle ID, and getDeviceLibraryFileName, with ':'
/// replaced by '_'.
///
/// Unless AllowMissingBundles is set, a target that receives no member is an
/// error. With it set, such a target gets an empty archive. An error while
/// iterating the input archive is returned rather than ignored.
Error OffloadBundler::UnbundleArchive() {
  // Owns every buffer an archive member below refers to; must stay alive
  // until all output archives have been written.
  std::vector<std::unique_ptr<MemoryBuffer>> ArchiveBuffers;
  // Members collected for each requested target, keyed by target name.
  StringMap<std::vector<NewArchiveMember>> OutputArchivesMap;
  // Output archive file name for each target.
  StringMap<StringRef> TargetOutputFileNameMap;

  auto Output = BundlerConfig.OutputFileNames.begin();
  for (auto &Target : BundlerConfig.TargetNames) {
    TargetOutputFileNameMap[Target] = *Output;
    ++Output;
  }

  StringRef IFName = BundlerConfig.InputFileNames.front();

  if (BundlerConfig.CheckInputArchive) {
    if (Error ArchiveError = CheckHeterogeneousArchive(IFName, BundlerConfig))
      return ArchiveError;
  }

  ErrorOr<std::unique_ptr<MemoryBuffer>> BufOrErr =
      MemoryBuffer::getFileOrSTDIN(IFName, true, false);
  if (std::error_code EC = BufOrErr.getError())
    return createFileError(BundlerConfig.InputFileNames.front(), EC);

  ArchiveBuffers.push_back(std::move(*BufOrErr));
  Expected<std::unique_ptr<llvm::object::Archive>> LibOrErr =
      Archive::create(ArchiveBuffers.back()->getMemBufferRef());
  if (!LibOrErr)
    return LibOrErr.takeError();
  auto Archive = std::move(*LibOrErr);

  Error ArchiveErr = Error::success();
  auto ChildEnd = Archive->child_end();
  for (auto ArchiveIter = Archive->child_begin(ArchiveErr);
       ArchiveIter != ChildEnd; ++ArchiveIter) {
    if (ArchiveErr)
      return ArchiveErr;
    auto ArchiveChildNameOrErr = (*ArchiveIter).getName();
    if (!ArchiveChildNameOrErr)
      return ArchiveChildNameOrErr.takeError();

    StringRef BundledObjectFile = sys::path::filename(*ArchiveChildNameOrErr);

    auto CodeObjectBufferRefOrErr = (*ArchiveIter).getMemoryBufferRef();
    if (!CodeObjectBufferRefOrErr)
      return CodeObjectBufferRefOrErr.takeError();
    auto TempCodeObjectBuffer =
        MemoryBuffer::getMemBuffer(*CodeObjectBufferRefOrErr, false);

    Expected<std::unique_ptr<MemoryBuffer>> DecompressedBufferOrErr =
        CompressedOffloadBundle::decompress(*TempCodeObjectBuffer,
                                            BundlerConfig.Verbose);
    if (!DecompressedBufferOrErr)
      return createStringError(
          inconvertibleErrorCode(),
          "Failed to decompress code object: " +
              llvm::toString(DecompressedBufferOrErr.takeError()));
    MemoryBuffer &CodeObjectBuffer = **DecompressedBufferOrErr;

    Expected<std::unique_ptr<FileHandler>> FileHandlerOrErr =
        CreateFileHandler(CodeObjectBuffer, BundlerConfig);
    if (!FileHandlerOrErr)
      return FileHandlerOrErr.takeError();
    std::unique_ptr<FileHandler> &FileHandler = *FileHandlerOrErr;
    assert(FileHandler &&
           "FileHandle creation failed for file in the archive!");

    if (Error ReadErr = FileHandler->ReadHeader(CodeObjectBuffer))
      return ReadErr;

    Expected<std::optional<StringRef>> CurBundleIDOrErr =
        FileHandler->ReadBundleStart(CodeObjectBuffer);
    if (!CurBundleIDOrErr)
      return CurBundleIDOrErr.takeError();
    std::optional<StringRef> OptionalCurBundleID = *CurBundleIDOrErr;
    // This member contains no bundle entries.
    if (!OptionalCurBundleID)
      continue;
    StringRef CodeObject = *OptionalCurBundleID;

    // Walk every bundle entry of this member.
    while (!CodeObject.empty()) {
      SmallVector<StringRef> CompatibleTargets;
      auto CodeObjectInfo = OffloadTargetInfo(CodeObject, BundlerConfig);
      if (getCompatibleOffloadTargets(CodeObjectInfo, CompatibleTargets,
                                      BundlerConfig)) {
        std::string BundleData;
        raw_string_ostream DataStream(BundleData);
        if (Error Err = FileHandler->ReadBundle(DataStream, CodeObjectBuffer))
          return Err;

        // The member name and contents depend only on this entry, not on the
        // target. Build one buffer and share it among all compatible targets.
        SmallString<128> BundledObjectFileName;
        BundledObjectFileName.assign(BundledObjectFile);
        auto OutputBundleName =
            Twine(llvm::sys::path::stem(BundledObjectFileName) + "-" +
                  CodeObject +
                  getDeviceLibraryFileName(BundledObjectFileName,
                                           CodeObjectInfo.TargetID))
                .str();
        // ':' separates target-ID features; keep it out of member names.
        std::replace(OutputBundleName.begin(), OutputBundleName.end(), ':',
                     '_');

        ArchiveBuffers.push_back(MemoryBuffer::getMemBufferCopy(
            DataStream.str(), OutputBundleName));
        llvm::MemoryBufferRef MemBufRef =
            MemoryBufferRef(*(ArchiveBuffers.back()));

        for (StringRef CompatibleTarget : CompatibleTargets)
          OutputArchivesMap[CompatibleTarget].push_back(
              NewArchiveMember(MemBufRef));
      }

      if (Error Err = FileHandler->ReadBundleEnd(CodeObjectBuffer))
        return Err;

      Expected<std::optional<StringRef>> NextTripleOrErr =
          FileHandler->ReadBundleStart(CodeObjectBuffer);
      if (!NextTripleOrErr)
        return NextTripleOrErr.takeError();
      CodeObject = ((*NextTripleOrErr).has_value()) ? **NextTripleOrErr : "";
    }
  }

  // A failure while advancing through the archive's members ends the loop
  // above with the error left in ArchiveErr. Report it instead of writing
  // archives built from a partially read input.
  if (ArchiveErr)
    return ArchiveErr;

  for (auto &Target : BundlerConfig.TargetNames) {
    StringRef FileName = TargetOutputFileNameMap[Target];
    auto CurArchiveMembers = OutputArchivesMap.find(Target);
    const bool HasMembers = CurArchiveMembers != OutputArchivesMap.end();

    if (!HasMembers && !BundlerConfig.AllowMissingBundles) {
      std::string ErrMsg =
          Twine("no compatible code object found for the target '" + Target +
                "' in heterogeneous archive library: " + IFName)
              .str();
      return createStringError(inconvertibleErrorCode(), ErrMsg);
    }

    // With AllowMissingBundles, a target without compatible code objects
    // still gets an archive; it is simply empty.
    ArrayRef<NewArchiveMember> Members;
    if (HasMembers)
      Members = ArrayRef<NewArchiveMember>(CurArchiveMembers->getValue());

    if (Error WriteErr = writeArchive(FileName, Members,
                                      SymtabWritingMode::NormalSymtab,
                                      getDefaultArchiveKindForHost(), true,
                                      false, nullptr))
      return WriteErr;
  }
  return Error::success();
}