#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "triton/Target/LLVMIR/Passes.h"
#include "llvm/BinaryFormat/Dwarf.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Path.h"
using namespace mlir;
#define GEN_PASS_CLASSES
#include "triton/Target/LLVMIR/Passes.h.inc"
namespace {
/// Best-effort extraction of a file/line/column location from `loc`.
///
/// Wrapper locations are unwrapped recursively:
///   - NameLoc     -> its child location
///   - OpaqueLoc   -> its fallback location
///   - CallSiteLoc -> its caller location
///   - FusedLoc    -> the first fused location that yields a known file
/// Never returns a null location: when no file information can be found, a
/// FileLineColLoc with filename "<unknown>" and line/column 0 is returned, so
/// callers may always query getFilename() on the result.
FileLineColLoc extractFileLoc(Location loc) {
  if (auto fileLoc = dyn_cast<FileLineColLoc>(loc))
    return fileLoc;
  if (auto nameLoc = dyn_cast<NameLoc>(loc))
    return extractFileLoc(nameLoc.getChildLoc());
  if (auto opaqueLoc = dyn_cast<OpaqueLoc>(loc))
    return extractFileLoc(opaqueLoc.getFallbackLocation());
  if (auto callerLoc = dyn_cast<CallSiteLoc>(loc))
    return extractFileLoc(callerLoc.getCaller());

  StringAttr unknownFile = mlir::StringAttr::get(loc.getContext(), "<unknown>");
  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
    // Look past fused children that carry no file information instead of
    // giving up after the first one; this also avoids calling front() on an
    // empty location list.
    for (Location child : fusedLoc.getLocations()) {
      FileLineColLoc childLoc = extractFileLoc(child);
      // StringAttrs are uniqued per context, so identity comparison suffices.
      if (childLoc.getFilename() != unknownFile)
        return childLoc;
    }
  }
  return mlir::FileLineColLoc::get(unknownFile, 0, 0);
}
/// Pass that attaches debug-info scopes to LLVM dialect IR.
///
/// Every LLVMFuncOp without a DISubprogramAttr in its location receives one,
/// fused into its location. Every other op whose location is a CallSiteLoc
/// has its callee location wrapped in DILexicalBlockFileAttr scopes rooted at
/// the enclosing function's subprogram.
struct LLVMDIScopePass : public LLVMDIScopeBase<LLVMDIScopePass> {
  LLVMDIScopePass() = default;

  /// Returns true when `fileLoc` carries no usable file information: it is
  /// null, or it is exactly the "<unknown>":0:0 placeholder that
  /// extractFileLoc() returns when it cannot find a file location.
  static bool isUnknownFileLoc(FileLineColLoc fileLoc) {
    return !fileLoc || (fileLoc.getFilename().getValue() == "<unknown>" &&
                        fileLoc.getLine() == 0 && fileLoc.getColumn() == 0);
  }

  /// Fuses a DISubprogramAttr into the location of `funcOp`, unless one is
  /// already present somewhere in that location.
  void setSubprogramAttr(LLVM::LLVMFuncOp funcOp) {
    Location loc = funcOp.getLoc();
    if (loc->findInstanceOf<mlir::FusedLocWith<LLVM::DISubprogramAttr>>())
      return;

    MLIRContext *context = &getContext();

    // Reuse a compile unit fused into the parent module's location, if any.
    LLVM::DICompileUnitAttr compileUnitAttr;
    if (ModuleOp module = funcOp->getParentOfType<ModuleOp>()) {
      auto fusedCompileUnitAttr =
          module->getLoc()
              ->findInstanceOf<mlir::FusedLocWith<LLVM::DICompileUnitAttr>>();
      if (fusedCompileUnitAttr)
        compileUnitAttr = fusedCompileUnitAttr.getMetadata();
    }

    // File and line to associate with the function. extractFileLoc() never
    // returns null, so the "no file" case must be detected via its
    // placeholder; otherwise the fallbacks below would be unreachable.
    LLVM::DIFileAttr fileAttr;
    int64_t line = 1;
    FileLineColLoc fileLoc = extractFileLoc(loc);
    if (!isUnknownFileLoc(fileLoc)) {
      line = fileLoc.getLine();
      StringRef inputFilePath = fileLoc.getFilename().getValue();
      fileAttr = LLVM::DIFileAttr::get(
          context, llvm::sys::path::filename(inputFilePath),
          llvm::sys::path::parent_path(inputFilePath));
    } else if (compileUnitAttr) {
      fileAttr = compileUnitAttr.getFile();
    } else {
      fileAttr = LLVM::DIFileAttr::get(context, "<unknown>", "");
    }

    auto subroutineTypeAttr =
        LLVM::DISubroutineTypeAttr::get(context, llvm::dwarf::DW_CC_normal, {});

    // Only definitions get a distinct id, a compile unit and the Definition
    // flag. The LLVM verifier rejects subprogram declarations that reference a
    // compile unit, hence the reset in the external case.
    DistinctAttr distinctId;
    auto subprogramFlags = LLVM::DISubprogramFlags::Optimized;
    if (!funcOp.isExternal()) {
      distinctId = mlir::DistinctAttr::create(mlir::UnitAttr::get(context));
      if (!compileUnitAttr) {
        // The compile unit is a different metadata node from the subprogram,
        // so give it its own distinct identity instead of sharing one.
        compileUnitAttr = LLVM::DICompileUnitAttr::get(
            mlir::DistinctAttr::create(mlir::UnitAttr::get(context)),
            llvm::dwarf::DW_LANG_C, fileAttr,
            StringAttr::get(context, "triton"),
            /*isOptimized=*/true, LLVM::DIEmissionKind::LineTablesOnly);
      }
      subprogramFlags = subprogramFlags | LLVM::DISubprogramFlags::Definition;
    } else {
      compileUnitAttr = {};
    }

    StringAttr funcNameAttr = funcOp.getNameAttr();
    auto subprogramAttr = LLVM::DISubprogramAttr::get(
        context, distinctId, compileUnitAttr, fileAttr, funcNameAttr,
        funcNameAttr, fileAttr, line, line, subprogramFlags,
        subroutineTypeAttr, {}, {});
    funcOp->setLoc(FusedLoc::get(context, {loc}, subprogramAttr));
  }

  /// Wraps `calleeLoc` in a FusedLoc whose metadata is a
  /// DILexicalBlockFileAttr. That attribute has parent `scopeAttr` and takes
  /// its file from extractFileLoc(calleeLoc).
  ///
  /// If `calleeLoc` is itself a CallSiteLoc, only its callee is kept. The
  /// callee is wrapped recursively, using the new lexical block as its scope.
  /// The nested call site's caller contributes only its file, via the
  /// lexical block.
  Location getNestedLoc(Operation *op, LLVM::DIScopeAttr scopeAttr,
                        Location calleeLoc) {
    MLIRContext *context = op->getContext();
    auto calleeFileName = extractFileLoc(calleeLoc).getFilename();
    LLVM::DIFileAttr calleeFileAttr = LLVM::DIFileAttr::get(
        context, llvm::sys::path::filename(calleeFileName),
        llvm::sys::path::parent_path(calleeFileName));
    auto lexicalBlockFileAttr = LLVM::DILexicalBlockFileAttr::get(
        context, scopeAttr, calleeFileAttr, /*discriminator=*/0);

    Location innerLoc = calleeLoc;
    if (auto nestedCallSite = dyn_cast<CallSiteLoc>(calleeLoc))
      innerLoc =
          getNestedLoc(op, lexicalBlockFileAttr, nestedCallSite.getCallee());
    return FusedLoc::get(context, {innerLoc}, lexicalBlockFileAttr);
  }

  /// Rewrites a CallSiteLoc on `op` so that its callee is scoped by lexical
  /// block files rooted at the enclosing function's subprogram. Ops without a
  /// CallSiteLoc, or without a reachable subprogram, are left untouched.
  void setLexicalBlockFileAttr(Operation *op) {
    auto callSiteLoc = dyn_cast<CallSiteLoc>(op->getLoc());
    if (!callSiteLoc)
      return;

    // The walk also visits ops that are not nested in any function (e.g. the
    // root module itself); there is no subprogram to scope those against.
    auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
    if (!funcOp)
      return;

    // setSubprogramAttr() leaves a function alone when a subprogram already
    // appears anywhere in its location. So search for it rather than assuming
    // the top-level location is a FusedLoc carrying it.
    auto fusedSubprogramLoc =
        funcOp.getLoc()
            ->findInstanceOf<mlir::FusedLocWith<LLVM::DISubprogramAttr>>();
    if (!fusedSubprogramLoc)
      return;
    LLVM::DIScopeAttr scopeAttr = fusedSubprogramLoc.getMetadata();

    Location scopedCallee =
        getNestedLoc(op, scopeAttr, callSiteLoc.getCallee());
    op->setLoc(CallSiteLoc::get(scopedCallee, callSiteLoc.getCaller()));
  }

  void runOnOperation() override {
    // Pre-order, so that each function's subprogram is attached before the
    // ops in its body are visited and scoped against it.
    getOperation()->walk<WalkOrder::PreOrder>([&](Operation *op) -> void {
      if (auto funcOp = dyn_cast<LLVM::LLVMFuncOp>(op))
        setSubprogramAttr(funcOp);
      else
        setLexicalBlockFileAttr(op);
    });
  }
};
}
/// Factory for the pass that attaches LLVM debug-info scopes (subprograms and
/// lexical block files) to LLVM dialect IR. Ownership passes to the caller.
std::unique_ptr<Pass> mlir::createLLVMDIScopePass() {
  std::unique_ptr<Pass> pass = std::make_unique<LLVMDIScopePass>();
  return pass;
}