#include <optional>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "mlir/Bytecode/BytecodeWriter.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/UB/IR/UBOps.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Transforms/LocationSnapshot.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/Support/SourceMgr.h"
#include "ir.h"
namespace {
namespace py = pybind11;
using namespace mlir;
using namespace triton;
std::string locationToString(Location loc) {
std::string str;
llvm::raw_string_ostream os(str);
loc.print(os);
os.flush();
return str;
}
void outputWarning(Location loc, const std::string &msg) {
std::string locStr = locationToString(loc);
PyErr_WarnEx(PyExc_UserWarning, (locStr + ": " + msg).c_str(),
2);
}
}
namespace ir {
static py::class_<TritonOpBuilder> *builderClassPtr = nullptr;
py::class_<TritonOpBuilder> *getBuilderClass() { return builderClassPtr; }
}
void init_triton_ir(py::module &&m) {
using ret = py::return_value_policy;
using namespace pybind11::literals;
py::enum_<PaddingOption>(m, "PADDING_OPTION", py::module_local())
.value("PAD_ZERO", PaddingOption::PAD_ZERO)
.value("PAD_NAN", PaddingOption::PAD_NAN)
.export_values();
py::enum_<CacheModifier>(m, "CACHE_MODIFIER", py::module_local())
.value("NONE", CacheModifier::NONE)
.value("CA", CacheModifier::CA)
.value("CG", CacheModifier::CG)
.value("WB", CacheModifier::WB)
.value("CS", CacheModifier::CS)
.value("WT", CacheModifier::WT)
.value("CV", CacheModifier::CV)
.export_values();
py::enum_<MemSemantic>(m, "MEM_SEMANTIC", py::module_local())
.value("ACQUIRE_RELEASE", MemSemantic::ACQUIRE_RELEASE)
.value("ACQUIRE", MemSemantic::ACQUIRE)
.value("RELEASE", MemSemantic::RELEASE)
.value("RELAXED", MemSemantic::RELAXED)
.export_values();
py::enum_<MemSyncScope>(m, "MEM_SYNC_SCOPE", py::module_local())
.value("GPU", MemSyncScope::GPU)
.value("CTA", MemSyncScope::CTA)
.value("SYSTEM", MemSyncScope::SYSTEM)
.export_values();
py::enum_<EvictionPolicy>(m, "EVICTION_POLICY", py::module_local())
.value("NORMAL", EvictionPolicy::NORMAL)
.value("EVICT_FIRST", EvictionPolicy::EVICT_FIRST)
.value("EVICT_LAST", EvictionPolicy::EVICT_LAST)
.export_values();
py::enum_<RMWOp>(m, "ATOMIC_OP", py::module_local())
.value("ADD", RMWOp::ADD)
.value("FADD", RMWOp::FADD)
.value("AND", RMWOp::AND)
.value("OR", RMWOp::OR)
.value("XOR", RMWOp::XOR)
.value("XCHG", RMWOp::XCHG)
.value("MAX", RMWOp::MAX)
.value("MIN", RMWOp::MIN)
.value("UMIN", RMWOp::UMIN)
.value("UMAX", RMWOp::UMAX);
py::enum_<RoundingMode>(m, "ROUNDING_MODE", py::module_local())
.value("RTZ", RoundingMode::RTZ)
.value("RTNE", RoundingMode::RTNE);
py::enum_<PropagateNan>(m, "PROPAGATE_NAN", py::module_local())
.value("NONE", PropagateNan::NONE)
.value("ALL", PropagateNan::ALL);
py::enum_<InputPrecision>(m, "INPUT_PRECISION", py::module_local())
.value("TF32", InputPrecision::TF32)
.value("TF32x3", InputPrecision::TF32x3)
.value("IEEE", InputPrecision::IEEE)
.value("HF32", InputPrecision::HF32)
.export_values();
py::enum_<F8F6F4Type>(m, "F8F6F4TY", py::module_local())
.value("E4M3", F8F6F4Type::E4M3)
.value("E5M2", F8F6F4Type::E5M2)
.value("E2M3", F8F6F4Type::E2M3)
.value("E3M2", F8F6F4Type::E3M2)
.value("E2M1", F8F6F4Type::E2M1)
.value("BF16", F8F6F4Type::BF16)
.value("FP16", F8F6F4Type::FP16)
.export_values();
py::class_<MLIRContext>(m, "context", py::module_local())
.def(py::init<>())
.def("__enter__", [](MLIRContext &self) -> MLIRContext& { return self; },
py::return_value_policy::reference)
.def("__exit__",
[](MLIRContext &, py::object, py::object, py::object) -> bool {
return false;
})
.def("printOpOnDiagnostic",
[](MLIRContext &self, bool v) { self.printOpOnDiagnostic(v); })
.def("printStackTraceOnDiagnostic",
[](MLIRContext &self, bool v) {
self.printStackTraceOnDiagnostic(v);
})
.def("disable_multithreading",
[](MLIRContext &self) { self.disableMultithreading(); });
py::class_<SourceMgrDiagnosticHandler>(m, "source_mgr_diag",
py::module_local())
.def(py::init<llvm::SourceMgr &, MLIRContext *>());
m.def("load_dialects", [](MLIRContext &context) {
DialectRegistry registry;
registry.insert<TritonDialect, ::mlir::triton::gpu::TritonGPUDialect,
math::MathDialect, arith::ArithDialect, index::IndexDialect,
scf::SCFDialect, ::mlir::gpu::GPUDialect,
cf::ControlFlowDialect, LLVM::LLVMDialect,
mlir::ub::UBDialect>();
mlir::LLVM::registerInlinerInterface(registry);
registerBuiltinDialectTranslation(registry);
registerLLVMDialectTranslation(registry);
mlir::LLVM::registerInlinerInterface(registry);
context.appendDialectRegistry(registry);
context.loadAllAvailableDialects();
});
py::class_<Type>(m, "type", py::module_local())
.def("is_integer",
[](Type &self, unsigned width) { return self.isInteger(width); })
.def("is_fp16", &Type::isF16)
.def("__str__", [](Type &self) {
std::string str;
llvm::raw_string_ostream os(str);
self.print(os);
return os.str();
});
py::class_<FunctionType>(m, "function_type", py::module_local())
.def("param_types", [](FunctionType &self) {
return std::vector<Type>(self.getInputs().begin(),
self.getInputs().end());
});
py::class_<Location>(m, "location", py::module_local())
.def("__str__", [](Location &self) {
std::string str;
llvm::raw_string_ostream os(str);
self.print(os);
return os.str();
});
py::class_<Value>(m, "value", py::module_local())
.def("set_attr",
[](Value &self, std::string &name, Attribute &attr) -> void {
if (Operation *definingOp = self.getDefiningOp())
definingOp->setAttr(name, attr);
else {
auto arg = mlir::cast<BlockArgument>(self);
int id = arg.getArgNumber();
std::string attrName = name + "_arg" + std::to_string(id);
Block *owner = arg.getOwner();
if (owner->isEntryBlock() &&
!isa<FuncOp>(owner->getParentOp())) {
owner->getParentOp()->setAttr(attrName, attr);
}
}
})
.def("get_context", &Value::getContext)
.def("replace_all_uses_with",
[](Value &self, Value &newValue) {
self.replaceAllUsesWith(newValue);
})
.def("get_type", &Value::getType)
.def("id", [](Value &self) {
return (uint64_t)self.getImpl();
});
py::class_<OpResult, Value>(m, "op_result", py::module_local());
py::class_<BlockArgument, Value>(m, "block_argument", py::module_local());
py::class_<Region>(m, "region", py::module_local())
.def("get_parent_region", &Region::getParentRegion, ret::reference)
.def("size", [](Region &self) { return self.getBlocks().size(); })
.def("empty", &Region::empty)
.def("id", [](Region &self) { return (uint64_t)&self; });
py::class_<Block>(m, "block", py::module_local())
.def("arg",
[](Block &self, int index) -> BlockArgument {
if (index >= self.getNumArguments())
throw pybind11::index_error("Block argument index out of range");
return self.getArgument(index);
})
.def("add_argument",
[](Block &self, Type ty) {
auto loc = UnknownLoc::get(ty.getContext());
self.addArgument(ty, loc);
})
.def("get_num_arguments", &Block::getNumArguments)
.def("get_argument", &Block::getArgument)
.def("dump", &Block::dump)
.def("move_before",
[](Block &self, Block &dst) { self.moveBefore(&dst); })
.def("insert_before", &Block::insertBefore)
.def("get_parent", &Block::getParent, ret::reference)
.def("merge_block_before",
[](Block &self, Block &dst) {
if (self.getNumArguments() != 0)
throw std::runtime_error(
"This block has arguments, don't merge");
dst.getOperations().splice(dst.begin(), self.getOperations());
self.dropAllUses();
self.erase();
})
.def("replace_use_in_block_with",
[](Block &self, Value &v, Value &newVal) {
v.replaceUsesWithIf(newVal, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
Block *currentBlock = user->getBlock();
while (currentBlock) {
if (currentBlock == &self)
return true;
currentBlock =
currentBlock->getParent()->getParentOp()->getBlock();
}
return false;
});
})
.def("__str__",
[](Block &self) {
std::string str;
llvm::raw_string_ostream os(str);
self.print(os);
return str;
})
.def("has_terminator",
[](Block &self) {
return !self.empty() &&
self.back().hasTrait<OpTrait::IsTerminator>();
})
.def("has_return",
[](Block &self) {
return !self.empty() &&
self.back().hasTrait<OpTrait::ReturnLike>();
})
.def("erase", [](Block &self) { self.erase(); })
.def("id", [](Block &self) { return (uint64_t)&self; });
py::class_<Attribute>(m, "attribute", py::module_local());
py::class_<IntegerAttr, Attribute>(m, "integer_attr", py::module_local());
py::class_<BoolAttr, Attribute>(m, "bool_attr", py::module_local());
py::class_<UnitAttr, Attribute>(m, "unit_attr", py::module_local());
py::class_<StringAttr, Attribute>(m, "str_attr", py::module_local());
py::class_<ArrayAttr, Attribute>(m, "array_attr", py::module_local());
py::class_<OpState>(m, "OpState", py::module_local())
.def("set_attr",
[](OpState &self, std::string &name, Attribute &attr) -> void {
self->setAttr(name, attr);
})
.def("get_num_results",
[](OpState &self) -> unsigned { return self->getNumResults(); })
.def("get_result",
[](OpState &self, unsigned idx) -> Value {
if (idx >= self->getNumResults())
throw pybind11::index_error("Op result index out of range");
return self->getResult(idx);
})
.def(
"get_region",
[](OpState &self, unsigned idx) -> Region & {
if (idx >= self->getNumRegions())
throw pybind11::index_error("Op region index out of range");
return self->getRegion(idx);
},
ret::reference)
.def(
"get_body",
[](scf::ForOp &self, unsigned idx) -> Block * {
if (idx >= self->getNumRegions())
throw pybind11::index_error("Op region index out of range");
return self.getBody(idx);
},
ret::reference)
.def("dump", [](OpState &self) { self->dump(); })
.def("__str__",
[](OpState &self) -> std::string {
std::string str;
llvm::raw_string_ostream os(str);
auto printingFlags = OpPrintingFlags();
printingFlags.enableDebugInfo();
self->print(os, printingFlags);
return str;
})
.def("append_operand",
[](OpState &self, Value &val) {
self->insertOperands(self->getNumOperands(), val);
})
.def("verify", [](OpState &self) -> bool {
return succeeded(verify(self.getOperation()));
});
py::class_<scf::ForOp, OpState>(m, "ForOp", py::module_local())
.def("get_induction_var", &scf::ForOp::getInductionVar);
py::class_<scf::IfOp, OpState>(m, "IfOp", py::module_local())
.def("get_then_block", &scf::IfOp::thenBlock, ret::reference)
.def("get_else_block", &scf::IfOp::elseBlock, ret::reference)
.def("get_then_yield", &scf::IfOp::thenYield)
.def("get_else_yield", &scf::IfOp::elseYield);
py::class_<scf::YieldOp, OpState>(m, "YieldOp", py::module_local());
py::class_<scf::WhileOp, OpState>(m, "WhileOp", py::module_local())
.def("get_before", &scf::WhileOp::getBefore, ret::reference)
.def("get_after", &scf::WhileOp::getAfter, ret::reference);
py::class_<scf::ConditionOp, OpState>(m, "ConditionOp", py::module_local());
py::class_<Operation, std::unique_ptr<Operation, py::nodelete>>(
m, "operation", py::module_local())
.def("get_name",
[](Operation &self) {
llvm::StringRef opName = self.getName().getStringRef();
return opName.str();
})
.def("get_num_operands", &Operation::getNumOperands)
.def("get_operand", &Operation::getOperand)
.def("get_num_results", &Operation::getNumResults)
.def("get_result", &Operation::getResult)
.def("get_num_regions", &Operation::getNumRegions)
.def("get_region", &Operation::getRegion, ret::reference)
.def("get_block", &Operation::getBlock, ret::reference)
.def("get_str_attr",
[](Operation &self, const std::string &name) -> py::object {
auto ret = self.getAttrOfType<StringAttr>(name);
if (!ret)
return py::none();
return py::str(ret.getValue().str());
})
.def("get_bool_attr",
[](Operation &self, const std::string &name) -> py::object {
auto ret = self.getAttrOfType<BoolAttr>(name);
if (!ret)
return py::none();
return py::bool_(ret.getValue());
})
.def("get_flat_symbol_ref_attr",
[](Operation &self, const std::string &name) -> py::object {
auto ret = self.getAttrOfType<FlatSymbolRefAttr>(name);
if (!ret)
return py::none();
return py::str(ret.getValue().str());
});
py::class_<ModuleOp, OpState>(m, "module", py::module_local(),
py::dynamic_attr())
.def("dump", &ModuleOp::dump)
.def("str",
[](ModuleOp &self) -> std::string {
std::string str;
llvm::raw_string_ostream os(str);
auto printingFlags = OpPrintingFlags();
printingFlags.enableDebugInfo();
self.print(os, printingFlags);
return str;
})
.def("push_back",
[](ModuleOp &self, FuncOp &funcOp) -> void {
self.push_back(funcOp);
})
.def("has_function",
[](ModuleOp &self, std::string &funcName) -> bool {
if (self.lookupSymbol(funcName))
return true;
return false;
})
.def("get_function",
[](ModuleOp &self, std::string &funcName) -> FuncOp {
return self.lookupSymbol<FuncOp>(funcName);
})
.def("get_int_attr",
[](ModuleOp &self, std::string name) -> py::object {
auto ret = self->getAttrOfType<IntegerAttr>(name);
if (!ret)
return py::none();
return py::int_(ret.getInt());
})
.def("create_location_snapshot",
[](ModuleOp &self, const std::string &fileName) -> void {
generateLocationsFromIR(llvm::nulls(),
fileName,
self, {});
})
.def("walk",
[](ModuleOp &self, const std::function<void(Operation *)> &fn) {
self.walk(fn);
});
m.def("make_attr", [](const std::vector<int> &values, MLIRContext &context) {
return mlir::cast<Attribute>(DenseIntElementsAttr::get(
RankedTensorType::get({static_cast<int64_t>(values.size())},
IntegerType::get(&context, 32)),
values));
});
m.def(
"parse_mlir_module",
[](const std::string &inputFilename, MLIRContext &context) {
OwningOpRef<ModuleOp> module =
parseSourceFile<ModuleOp>(inputFilename, &context);
if (!module)
throw std::runtime_error("Parse MLIR file failed.");
return module->clone();
},
ret::take_ownership);
py::class_<FuncOp, OpState>(m, "function", py::module_local())
.def("args",
[](FuncOp &self, unsigned idx) -> BlockArgument {
if (idx >= self.getNumArguments())
throw pybind11::index_error(
"Function argument index out of range");
return self.getArgument(idx);
})
.def(
"add_entry_block",
[](FuncOp &self) -> Block * { return self.addEntryBlock(); },
ret::reference)
.def(
"set_arg_attr",
[](FuncOp &self, int arg_no, const std::string &name, int val) {
if (arg_no >= self.getNumArguments())
throw pybind11::index_error(
"Function argument index out of range");
auto attrTy = IntegerType::get(self.getContext(), 32);
self.setArgAttr(arg_no, name, IntegerAttr::get(attrTy, val));
},
ret::reference)
.def("finalize",
[](FuncOp &self) -> void {
self.walk([&](AdvanceOp op) {
if (op->getResult(0).use_empty())
outputWarning(op->getLoc(), "The result of tl.advance is not "
"being used. Note that tl.advance "
"does not have any side effects. "
"To move the block pointer, you "
"need to assign the result of "
"tl.advance to a variable.");
});
})
.def_property_readonly("type", &FuncOp::getFunctionType)
.def("reset_type", &FuncOp::setType);
py::class_<OpBuilder::InsertPoint>(m, "InsertPoint", py::module_local());
static py::class_<TritonOpBuilder> builderClass(
m, "builder", py::module_local(), py::dynamic_attr());
ir::builderClassPtr = &builderClass;
builderClass.def(py::init<MLIRContext *, const std::string &>(),
py::arg("context"),
py::arg("compile_mode") = "simd",
"Create a TritonOpBuilder with optional compile_mode (simt or simd, default: simd)")
.def("is_simt_mode", &TritonOpBuilder::isSimtMode,
"Check if the compile mode is simt")
.def("create_module",
[](TritonOpBuilder &self) -> ModuleOp {
return self.create<ModuleOp>();
})
.def("set_insertion_point_to_start",
[](TritonOpBuilder &self, Block &block) -> void {
self.setInsertionPointToStart(block);
})
.def("set_insertion_point_to_end",
[](TritonOpBuilder &self, Block &block) {
self.setInsertionPointToEnd(block);
})
.def("set_insertion_point_after",
[](TritonOpBuilder &self, Operation &op) {
self.setInsertionPointAfter(op);
})
.def(
"get_insertion_block",
[](TritonOpBuilder &self) -> Block * {
return self.getBuilder().getInsertionBlock();
},
ret::reference)
.def("get_insertion_point",
[](TritonOpBuilder &self) {
return self.getBuilder().saveInsertionPoint();
})
.def("restore_insertion_point",
[](TritonOpBuilder &self, OpBuilder::InsertPoint pt) {
self.restoreInsertionPoint(pt);
})
.def("get_bool_attr",
[](TritonOpBuilder &self, bool value) {
return self.getBuilder().getBoolAttr(value);
})
.def("get_int32_attr",
[](TritonOpBuilder &self, int32_t value) {
return self.getBuilder().getI32IntegerAttr(value);
})
.def("get_str_attr",
[](TritonOpBuilder &self, std::string value) {
return self.getBuilder().getStringAttr(value);
})
.def("get_unit_attr",
[](TritonOpBuilder &self) {
return self.getBuilder().getUnitAttr();
})
.def("get_i64_array_attr",
[](TritonOpBuilder &self, const std::vector<int64_t>& array) {
return self.getBuilder().getI64ArrayAttr(array);
})
.def("get_type_array_attr",
[](TritonOpBuilder &self, const std::vector<Type>& array) {
return self.getBuilder().getTypeArrayAttr(array);
})
.def("get_int1",
[](TritonOpBuilder &self, bool v) -> Value {
return Value(self.create<arith::ConstantIntOp>(
v, self.getBuilder().getI1Type()));
})
.def("get_int8",
[](TritonOpBuilder &self, int64_t v) -> Value {
return Value(self.create<arith::ConstantIntOp>(
v, self.getBuilder().getI8Type()));
})
.def("get_int16",
[](TritonOpBuilder &self, int64_t v) -> Value {
return Value(self.create<arith::ConstantIntOp>(
v, self.getBuilder().getI16Type()));
})
.def("get_int32",
[](TritonOpBuilder &self, int64_t v) -> Value {
return Value(self.create<arith::ConstantIntOp>(
v, self.getBuilder().getI32Type()));
})
.def("get_int64",
[](TritonOpBuilder &self, int64_t v) -> Value {
return Value(self.create<arith::ConstantIntOp>(
v, self.getBuilder().getI64Type()));
})
.def("get_uint8",
[](TritonOpBuilder &self, uint64_t v) -> Value {
return Value(self.create<arith::ConstantIntOp>(
v, self.getBuilder().getI8Type()));
})
.def("get_uint16",
[](TritonOpBuilder &self, uint64_t v) -> Value {
return Value(self.create<arith::ConstantIntOp>(
v, self.getBuilder().getI16Type()));
})
.def("get_uint32",
[](TritonOpBuilder &self, uint64_t v) -> Value {
return Value(self.create<arith::ConstantIntOp>(
v, self.getBuilder().getI32Type()));
})
.def("get_uint64",
[](TritonOpBuilder &self, uint64_t v) -> Value {
return Value(self.create<arith::ConstantIntOp>(
v, self.getBuilder().getI64Type()));
})
.def("get_bf16",
[](TritonOpBuilder &self, float v) -> Value {
auto type = self.getBuilder().getBF16Type();
return self.create<arith::ConstantFloatOp>(
APFloat(type.getFloatSemantics(), std::to_string(v)), type);
})
.def("get_fp16",
[](TritonOpBuilder &self, float v) -> Value {
return self.create<arith::ConstantOp>(
self.getBuilder().getF16FloatAttr(v));
})
.def("get_fp32",
[](TritonOpBuilder &self, float v) -> Value {
return self.create<arith::ConstantOp>(
self.getBuilder().getF32FloatAttr(v));
})
.def("get_fp64",
[](TritonOpBuilder &self, double v) -> Value {
return self.create<arith::ConstantOp>(
self.getBuilder().getF64FloatAttr(v));
})
.def("get_fp8e4nv",
[](TritonOpBuilder &self, double v) -> Value {
return self.create<arith::ConstantOp>(
FloatAttr::get(self.getBuilder().getFloat8E4M3FNType(), v));
})
.def("get_fp8e5",
[](TritonOpBuilder &self, double v) -> Value {
return self.create<arith::ConstantOp>(
FloatAttr::get(self.getBuilder().getFloat8E5M2Type(), v));
})
.def("get_null_value",
[](TritonOpBuilder &self, Type type) -> Value {
if (auto floatTy = dyn_cast<FloatType>(type))
return self.create<arith::ConstantFloatOp>(
APFloat(floatTy.getFloatSemantics(), 0), floatTy);
else if (auto intTy = dyn_cast<IntegerType>(type))
return self.create<arith::ConstantIntOp>(0, intTy);
else
throw std::runtime_error("Not implemented");
})
.def("get_all_ones_value",
[](TritonOpBuilder &self, Type type) -> Value {
uint64_t val = 0xFFFFFFFFFFFFFFFF;
if (auto intTy = dyn_cast<IntegerType>(type))
return self.create<arith::ConstantIntOp>(val, intTy);
else
throw std::runtime_error("Not implemented");
})
.def("get_void_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getNoneType();
})
.def("get_int1_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getI1Type();
})
.def("get_int8_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getI8Type();
})
.def("get_int16_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getType<IntegerType>(16);
})
.def("get_int32_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getI32Type();
})
.def("get_int64_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getI64Type();
})
.def("get_fp8e4nv_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getType<Float8E4M3FNType>();
})
.def("get_fp8e4b8_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getType<Float8E4M3FNUZType>();
})
.def("get_fp8e4b15_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getI8Type();
})
.def("get_fp8e5_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getType<Float8E5M2Type>();
})
.def("get_fp8e5b16_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getType<Float8E5M2FNUZType>();
})
.def("get_half_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getF16Type();
})
.def("get_bf16_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getBF16Type();
})
.def("get_float_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getF32Type();
})
.def("get_double_ty",
[](TritonOpBuilder &self) -> Type {
return self.getBuilder().getF64Type();
})
.def("get_ptr_ty",
[](TritonOpBuilder &self, Type &type, int addrSpace) -> Type {
return PointerType::get(type, addrSpace);
})
.def("get_block_ty",
[](TritonOpBuilder &self, Type &elementType,
std::vector<int64_t> &shape) -> Type {
return RankedTensorType::get(shape, elementType);
})
.def("get_buffer_ty",
[](TritonOpBuilder &self, std::vector<int64_t> &shape,
Type &elementType, const Attribute &memorySpace) -> Type {
return MemRefType::get(shape, elementType,
MemRefLayoutAttrInterface{}, memorySpace);
})
.def("get_buffer_ty_with_strides",
[](TritonOpBuilder &self, std::vector<int64_t> &shape,
Type &elementType, const std::vector<int64_t> &strides,
const Attribute &memorySpace) -> Type {
auto layout = StridedLayoutAttr::get(
self.getBuilder().getContext(), ShapedType::kDynamic, strides);
return MemRefType::get(shape, elementType, layout, memorySpace);
})
.def("get_function_ty",
[](TritonOpBuilder &self, std::vector<Type> inTypes,
std::vector<Type> outTypes) -> Type {
return self.getBuilder().getFunctionType(inTypes, outTypes);
})
.def("set_loc",
[](TritonOpBuilder &self, Location loc) { self.setLastLoc(loc); })
.def("set_loc",
[](TritonOpBuilder &self, const std::string &fileName, int line,
int column) { self.setLastLoc(fileName, line, column); })
.def("get_loc",
[](TritonOpBuilder &self) -> Location { return self.getLastLoc(); })
.def("get_or_insert_function",
[](TritonOpBuilder &self, ModuleOp &module, std::string &funcName,
Type &funcType, std::string &visibility,
bool noinline) -> FuncOp {
if (Operation *funcOperation = module.lookupSymbol(funcName))
return llvm::dyn_cast<FuncOp>(funcOperation);
if (auto funcTy = dyn_cast<FunctionType>(funcType)) {
llvm::SmallVector<NamedAttribute> attrs = {
NamedAttribute(
self.getBuilder().getStringAttr("sym_visibility"),
self.getBuilder().getStringAttr(visibility)),
NamedAttribute(self.getBuilder().getStringAttr("noinline"),
self.getBuilder().getBoolAttr(noinline))};
return self.create<FuncOp>(funcName, funcTy, attrs);
}
throw std::invalid_argument("invalid function type");
})
.def(
"create_block",
[](TritonOpBuilder &self) -> Block * {
Region *parent = self.getBuilder().getBlock()->getParent();
return self.getBuilder().createBlock(parent);
},
ret::reference)
.def(
"create_block_with_parent",
[](TritonOpBuilder &self, Region &parent,
std::vector<Type> &argTypes) -> Block * {
auto loc = self.getBuilder().getUnknownLoc();
llvm::SmallVector<Location, 8> argLocs(argTypes.size(), loc);
return self.getBuilder().createBlock(&parent, {}, argTypes,
argLocs);
},
ret::reference)
.def(
"new_block",
[](TritonOpBuilder &self) -> Block * { return new Block(); },
ret::reference)
.def("ret",
[](TritonOpBuilder &self, std::vector<Value> &vals) -> OpState {
return self.create<ReturnOp>(vals);
})
.def("call",
[](TritonOpBuilder &self, FuncOp &func, std::vector<Value> &args)
-> OpState { return self.create<CallOp>(func, args); })
.def("create_cond_branch",
[](TritonOpBuilder &self, Value condition, Block *trueDest,
Block *falseDest) -> OpState {
return self.create<cf::CondBranchOp>(condition, trueDest,
falseDest);
})
.def("create_branch",
[](TritonOpBuilder &self, Block *dest, std::vector<Value> &args)
-> OpState { return self.create<cf::BranchOp>(dest, args); })
.def("create_for_op",
[](TritonOpBuilder &self, Value &lb, Value &ub, Value &step,
std::vector<Value> &initArgs) -> scf::ForOp {
return self.create<scf::ForOp>(lb, ub, step, initArgs);
})
.def("create_if_op",
[](TritonOpBuilder &self, std::vector<Type> &retTypes,
Value &condition, bool withElse) -> scf::IfOp {
return self.create<scf::IfOp>(retTypes, condition, withElse);
})
.def("create_yield_op",
[](TritonOpBuilder &self, std::vector<Value> &yields)
-> scf::YieldOp { return self.create<scf::YieldOp>(yields); })
.def("create_while_op",
[](TritonOpBuilder &self, std::vector<Type> &retTypes,
std::vector<Value> &initArgs) -> scf::WhileOp {
return self.create<scf::WhileOp>(retTypes, initArgs);
})
.def("create_condition_op",
[](TritonOpBuilder &self, Value &cond,
std::vector<Value> &args) -> scf::ConditionOp {
return self.create<scf::ConditionOp>(cond, args);
})
.def("create_make_range",
[](TritonOpBuilder &self, int start, int end) -> Value {
auto retType = RankedTensorType::get(
{end - start}, self.getBuilder().getI32Type());
return self.create<MakeRangeOp>(retType, start, end);
})
.def("create_fp_to_fp",
[](TritonOpBuilder &self, Value &src, Type &dstType,
std::optional<RoundingMode> roundingMode) -> Value {
if (roundingMode.has_value())
return self.create<FpToFpOp>(
dstType, src,
RoundingModeAttr::get(self.getBuilder().getContext(),
roundingMode.value()));
else
return self.create<FpToFpOp>(dstType, src);
})
.def("create_bitcast",
[](TritonOpBuilder &self, Value &src, Type &dstType) -> Value {
return self.create<BitcastOp>(dstType, src);
})
.def("create_si_to_fp",
[](TritonOpBuilder &self, Value &src, Type &dstType) -> Value {
return self.create<arith::SIToFPOp>(dstType, src);
})
.def("create_ui_to_fp",
[](TritonOpBuilder &self, Value &src, Type &dstType) -> Value {
return self.create<arith::UIToFPOp>(dstType, src);
})
.def("create_fp_to_si",
[](TritonOpBuilder &self, Value &src, Type &dstType) -> Value {
return self.create<arith::FPToSIOp>(dstType, src);
})
.def("create_fp_to_ui",
[](TritonOpBuilder &self, Value &src, Type &dstType) -> Value {
return self.create<arith::FPToUIOp>(dstType, src);
})
.def("create_fp_ext",
[](TritonOpBuilder &self, Value &src, Type &dstType) -> Value {
return self.create<arith::ExtFOp>(dstType, src);
})
.def("create_fp_trunc",
[](TritonOpBuilder &self, Value &src, Type &dstType) -> Value {
return self.create<arith::TruncFOp>(dstType, src);
})
.def("create_int_cast",
[](TritonOpBuilder &self, Value &src, Type &dstType,
bool isSigned) -> Value {
Type srcType = src.getType();
auto srcTensorType = dyn_cast<RankedTensorType>(srcType);
auto dstTensorType = dyn_cast<RankedTensorType>(dstType);
Type srcEltType = srcType;
Type dstEltType = dstType;
if (dstTensorType && srcTensorType) {
dstEltType = dstTensorType.getElementType();
srcEltType = srcTensorType.getElementType();
}
unsigned srcWidth = srcEltType.getIntOrFloatBitWidth();
unsigned dstWidth = dstEltType.getIntOrFloatBitWidth();
if (srcWidth == dstWidth)
return self.create<arith::BitcastOp>(dstType, src);
else if (srcWidth > dstWidth)
return self.create<arith::TruncIOp>(dstType, src);
else if (isSigned)
return self.create<arith::ExtSIOp>(dstType, src);
else
return self.create<arith::ExtUIOp>(dstType, src);
})
.def("create_to_index",
[](TritonOpBuilder &self, Value &input) -> Value {
return self.create<arith::IndexCastOp>(
self.getBuilder().getIndexType(), input);
})
.def("create_index_to_si",
[](TritonOpBuilder &self, Value &input) -> Value {
return self.create<arith::IndexCastOp>(
self.getBuilder().getI64Type(), input);
})
.def("create_fmul",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::MulFOp>(lhs, rhs);
})
.def("create_fdiv",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::DivFOp>(lhs, rhs);
})
.def("create_frem",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::RemFOp>(lhs, rhs);
})
.def("create_fadd",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::AddFOp>(lhs, rhs);
})
.def("create_fsub",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::SubFOp>(lhs, rhs);
})
.def("create_mul",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::MulIOp>(lhs, rhs);
})
.def("create_umulhi",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<triton::MulhiUIOp>(lhs, rhs);
})
.def("create_sdiv",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::DivSIOp>(lhs, rhs);
})
.def("create_udiv",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::DivUIOp>(lhs, rhs);
})
.def("create_srem",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::RemSIOp>(lhs, rhs);
})
.def("create_urem",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::RemUIOp>(lhs, rhs);
})
.def("create_add",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::AddIOp>(lhs, rhs);
})
.def("create_sub",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::SubIOp>(lhs, rhs));
})
.def("create_fma",
[](TritonOpBuilder &self, Value &a, Value &b, Value &c) -> Value {
return Value(self.create<math::FmaOp>(a, b, c));
})
.def("create_shl",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::ShLIOp>(lhs, rhs));
})
.def("create_lshr",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::ShRUIOp>(lhs, rhs));
})
.def("create_ashr",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::ShRSIOp>(lhs, rhs));
})
.def("create_minsi",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::MinSIOp>(lhs, rhs));
})
.def("create_minui",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::MinUIOp>(lhs, rhs));
})
.def("create_minimumf",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::MinimumFOp>(lhs, rhs));
})
.def("create_minnumf",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::MinNumFOp>(lhs, rhs));
})
.def("create_maxsi",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::MaxSIOp>(lhs, rhs));
})
.def("create_maxui",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::MaxUIOp>(lhs, rhs));
})
.def("create_maximumf",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::MaximumFOp>(lhs, rhs));
})
.def("create_maxnumf",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<arith::MaxNumFOp>(lhs, rhs));
})
.def("create_clampf",
[](TritonOpBuilder &self, Value &input, Value &min, Value &max,
PropagateNan propagateNan) -> Value {
return Value(self.create<ClampFOp>(input, min, max, propagateNan));
})
.def("create_precise_sqrt",
[](TritonOpBuilder &self, Value &input) -> Value {
return Value(self.create<PreciseSqrtOp>(input));
})
.def("create_precise_divf",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return Value(self.create<PreciseDivFOp>(lhs, rhs));
})
.def("create_addptr",
[](TritonOpBuilder &self, Value &ptr, Value &offset) -> Value {
return self.create<AddPtrOp>(ptr.getType(), ptr, offset);
})
.def("create_icmpSLE",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpIOp>(arith::CmpIPredicate::sle, lhs,
rhs);
})
.def("create_icmpSLT",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpIOp>(arith::CmpIPredicate::slt, lhs,
rhs);
})
.def("create_icmpSGE",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpIOp>(arith::CmpIPredicate::sge, lhs,
rhs);
})
.def("create_icmpSGT",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpIOp>(arith::CmpIPredicate::sgt, lhs,
rhs);
})
.def("create_icmpULE",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpIOp>(arith::CmpIPredicate::ule, lhs,
rhs);
})
.def("create_icmpULT",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpIOp>(arith::CmpIPredicate::ult, lhs,
rhs);
})
.def("create_icmpUGE",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpIOp>(arith::CmpIPredicate::uge, lhs,
rhs);
})
.def("create_icmpUGT",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpIOp>(arith::CmpIPredicate::ugt, lhs,
rhs);
})
.def("create_icmpEQ",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpIOp>(arith::CmpIPredicate::eq, lhs,
rhs);
})
.def("create_icmpNE",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpIOp>(arith::CmpIPredicate::ne, lhs,
rhs);
})
.def("create_fcmpOLT",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, lhs,
rhs);
})
.def("create_fcmpOGT",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, lhs,
rhs);
})
.def("create_fcmpOLE",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpFOp>(arith::CmpFPredicate::OLE, lhs,
rhs);
})
.def("create_fcmpOGE",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpFOp>(arith::CmpFPredicate::OGE, lhs,
rhs);
})
.def("create_fcmpOEQ",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpFOp>(arith::CmpFPredicate::OEQ, lhs,
rhs);
})
.def("create_fcmpONE",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpFOp>(arith::CmpFPredicate::ONE, lhs,
rhs);
})
.def("create_fcmpULT",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpFOp>(arith::CmpFPredicate::ULT, lhs,
rhs);
})
.def("create_fcmpUGT",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpFOp>(arith::CmpFPredicate::UGT, lhs,
rhs);
})
.def("create_fcmpULE",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpFOp>(arith::CmpFPredicate::ULE, lhs,
rhs);
})
.def("create_fcmpUGE",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpFOp>(arith::CmpFPredicate::UGE, lhs,
rhs);
})
.def("create_fcmpUEQ",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpFOp>(arith::CmpFPredicate::UEQ, lhs,
rhs);
})
.def("create_fcmpUNE",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::CmpFOp>(arith::CmpFPredicate::UNE, lhs,
rhs);
})
.def("create_and",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::AndIOp>(lhs, rhs);
})
.def("create_xor",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::XOrIOp>(lhs, rhs);
})
.def("create_or",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
return self.create<arith::OrIOp>(lhs, rhs);
})
.def("create_load",
[](TritonOpBuilder &self, Value &ptrs, CacheModifier cacheModifier,
EvictionPolicy evictionPolicy, bool isVolatile) -> Value {
return self.create<LoadOp>(ptrs, cacheModifier, evictionPolicy,
isVolatile);
})
.def("create_store",
[](TritonOpBuilder &self, Value &ptrs, Value &value,
CacheModifier cacheModifier,
EvictionPolicy evictionPolicy) -> void {
self.create<StoreOp>(ptrs, value, cacheModifier, evictionPolicy);
})
.def("create_tensor_pointer_load",
[](TritonOpBuilder &self, Value &ptr,
std::vector<int32_t> &boundaryCheck,
std::optional<PaddingOption> paddingOption,
CacheModifier cacheModifier, EvictionPolicy evictionPolicy,
bool isVolatile) -> Value {
return self.create<LoadOp>(ptr, boundaryCheck, paddingOption,
cacheModifier, evictionPolicy,
isVolatile);
})
.def("create_tensor_pointer_store",
[](TritonOpBuilder &self, Value &ptr, Value &val,
std::vector<int32_t> &boundaryCheck, CacheModifier cacheModifier,
EvictionPolicy evictionPolicy) -> void {
self.create<StoreOp>(ptr, val, boundaryCheck, cacheModifier,
evictionPolicy);
})
.def("create_masked_load",
[](TritonOpBuilder &self, Value &ptrs, Value &mask,
std::optional<Value> &other, CacheModifier cacheModifier,
EvictionPolicy evictionPolicy, bool isVolatile) -> Value {
return self.create<LoadOp>(ptrs, mask, other.value_or(Value()),
cacheModifier, evictionPolicy,
isVolatile);
})
.def("create_masked_store",
[](TritonOpBuilder &self, Value &ptrs, Value &val, Value &mask,
CacheModifier cacheModifier,
EvictionPolicy evictionPolicy) -> void {
self.create<StoreOp>(ptrs, val, mask, cacheModifier,
evictionPolicy);
})
.def("create_tensor_descriptor_type",
[](TritonOpBuilder &self, Type blockTy, bool isSigned) -> Type {
auto ctx = self.getBuilder().getContext();
return triton::TensorDescType::get(ctx, cast<RankedTensorType>(blockTy), isSigned);
})
.def("create_descriptor_load",
[](TritonOpBuilder &self, Value desc, std::vector<Value> &indices, CacheModifier cacheModifier,
EvictionPolicy evictionPolicy) -> Value {
auto descTy = cast<triton::TensorDescType>(desc.getType());
auto resTy = descTy.getSignlessBlockType();
return self.create<DescriptorLoadOp>(resTy, desc, indices, cacheModifier, evictionPolicy);
})
.def("create_descriptor_store",
[](TritonOpBuilder &self, Value desc, Value value, std::vector<Value> &indices) -> void {
self.create<DescriptorStoreOp>(desc, value, indices);
})
.def("create_make_tensor_descriptor",
[](TritonOpBuilder &self, Value &base, std::vector<Value> &shape, std::vector<Value> &strides,
std::vector<int32_t> &tensorShape, bool isSignedInteger) -> Value {
return self.create<MakeTensorDescOp>(base, shape, strides, tensorShape, isSignedInteger);
})
.def("create_tensormap_create",
[](TritonOpBuilder &self, Value desc_ptr, Value global_address,
std::vector<Value> box_dim, std::vector<Value> global_dim,
std::vector<Value> global_stride,
std::vector<Value> element_stride, int32_t elem_type,
int32_t interleave_layout, int32_t swizzle_mode,
int32_t fill_mode) {
self.create<ExperimentalTensormapCreateOp>(
desc_ptr, global_address, box_dim, global_dim, global_stride,
element_stride, elem_type, interleave_layout, swizzle_mode,
fill_mode);
})
.def("create_tensormap_fenceproxy_acquire",
[](TritonOpBuilder &self, Value desc_ptr) {
self.create<ExperimentalTensormapFenceproxyAcquireOp>(desc_ptr);
})
.def("create_reshape",
[](TritonOpBuilder &self, Value &arg, std::vector<int64_t> &shape,
bool allowReorder) -> Value {
auto argType =
cast<RankedTensorType>(arg.getType()).getElementType();
return self.create<ReshapeOp>(
RankedTensorType::get(shape, argType), arg, allowReorder);
})
.def("create_expand_dims",
[](TritonOpBuilder &self, Value &arg, int axis) -> Value {
auto argType = dyn_cast<RankedTensorType>(arg.getType());
auto argEltType = argType.getElementType();
std::vector<int64_t> retShape = argType.getShape();
retShape.insert(retShape.begin() + axis, 1);
return self.create<ExpandDimsOp>(
RankedTensorType::get(retShape, argEltType), arg, axis);
})
.def("create_cat",
[](TritonOpBuilder &self, Value &lhs, Value &rhs) -> Value {
auto lhsType = dyn_cast<RankedTensorType>(lhs.getType());
auto rhsType = dyn_cast<RankedTensorType>(rhs.getType());
if (!(lhsType.getShape().size() == 1 &&
rhsType.getShape().size() == 1))
throw std::invalid_argument(
"shape not supported by cat. Expecting rank-1 inputs");
std::vector<int64_t> shape{lhsType.getShape()[0] +
rhsType.getShape()[0]};
return self.create<CatOp>(
RankedTensorType::get(shape, lhsType.getElementType()), lhs,
rhs);
})
.def("create_join",
[](TritonOpBuilder &self, Value &a, Value &b) -> Value {
return self.create<JoinOp>(a, b);
})
.def("create_split",
[](TritonOpBuilder &self, Value &a) -> std::vector<Value> {
auto op = self.create<SplitOp>(a);
return std::vector<Value>(op->result_begin(), op->result_end());
})
.def("create_trans",
[](TritonOpBuilder &self, Value &arg,
std::vector<int> &order) -> Value {
auto argType = dyn_cast<RankedTensorType>(arg.getType());
auto argEltType = argType.getElementType();
auto retShape = applyPermutation(argType.getShape(), order);
return self.create<TransOp>(
RankedTensorType::get(retShape, argEltType), arg, order);
})
.def("create_broadcast",
[](TritonOpBuilder &self, Value &arg,
std::vector<int64_t> &shape) -> Value {
if (auto argType = dyn_cast<RankedTensorType>(arg.getType()))
return self.createOrFold<BroadcastOp>(
RankedTensorType::get(shape, argType.getElementType()), arg);
throw std::invalid_argument(
"arg is not of RankedTensorType, use create_splat");
})
.def("create_splat",
[](TritonOpBuilder &self, Value &arg,
std::vector<int64_t> &shape) -> Value {
auto argType = arg.getType();
auto ret = self.createOrFold<SplatOp>(
RankedTensorType::get(shape, argType), arg);
return ret;
})
.def("create_atomic_cas",
[](TritonOpBuilder &self, Value &ptr, Value &cmp, Value &val,
MemSemantic sem, MemSyncScope scope) -> Value {
Type dstType;
if (auto srcTensorType =
dyn_cast<RankedTensorType>(ptr.getType())) {
Type dstElemType =
cast<PointerType>(srcTensorType.getElementType())
.getPointeeType();
dstType =
RankedTensorType::get(srcTensorType.getShape(), dstElemType);
} else {
auto ptrType = cast<PointerType>(getElementTypeOrSelf(ptr));
dstType = ptrType.getPointeeType();
}
return self.create<AtomicCASOp>(dstType, ptr, cmp, val, sem,
scope);
})
.def("create_atomic_rmw",
[](TritonOpBuilder &self, RMWOp rmwOp, Value &ptr, Value &val,
Value &mask, MemSemantic sem, MemSyncScope scope) -> Value {
Type dstType;
if (auto srcTensorType =
dyn_cast<RankedTensorType>(ptr.getType())) {
Type dstElemType =
cast<PointerType>(srcTensorType.getElementType())
.getPointeeType();
dstType =
RankedTensorType::get(srcTensorType.getShape(), dstElemType);
} else {
auto ptrType = cast<PointerType>(getElementTypeOrSelf(ptr));
dstType = ptrType.getPointeeType();
}
return self.create<AtomicRMWOp>(dstType, rmwOp, ptr, val, mask,
sem, scope);
})
.def("create_extern_elementwise",
[](TritonOpBuilder &self, const std::string &libName,
const std::string &libPath, const std::string &symbol,
std::vector<Value> &argList, Type retType, bool isPure) -> Value {
return self.create<ExternElementwiseOp>(retType, argList, libName,
libPath, symbol, isPure);
})
.def("create_get_program_id",
[](TritonOpBuilder &self, int axis) -> Value {
if (axis < 0 || axis > 3)
throw pybind11::index_error("program_id must be in [0,3]");
return self.create<GetProgramIdOp>(axis);
})
.def("create_get_num_programs",
[](TritonOpBuilder &self, int axis) -> Value {
if (axis < 0 || axis > 3)
throw pybind11::index_error("program_id must be in [0,3]");
return self.create<GetNumProgramsOp>(axis);
})
.def("create_dot",
[](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b,
mlir::Value &c, InputPrecision inputPrecision,
int maxNumImpreciseAcc) -> mlir::Value {
return self.create<DotOp>(c.getType(), a, b, c, inputPrecision,
maxNumImpreciseAcc);
})
.def("create_dot_scaled",
[](TritonOpBuilder &self, mlir::Value &lhs, mlir::Value &lhs_scale,
F8F6F4Type lhs_format, mlir::Value &rhs,
std::optional<mlir::Value> &rhs_scale, F8F6F4Type rhs_format,
mlir::Value &c) -> mlir::Value {
return self.create<DotScaledOp>(
c.getType(), lhs, rhs, c, lhs_scale,
rhs_scale.value_or(Value()), lhs_format, rhs_format);
})
.def("create_floor",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::FloorOp>(val);
})
.def("create_ceil",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::CeilOp>(val);
})
.def("create_exp",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::ExpOp>(val);
})
.def("create_exp2",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::Exp2Op>(val);
})
.def("create_cos",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::CosOp>(val);
})
.def("create_sin",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::SinOp>(val);
})
.def("create_log",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::LogOp>(val);
})
.def("create_log2",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::Log2Op>(val);
})
.def("create_erf",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::ErfOp>(val);
})
.def("create_sqrt",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::SqrtOp>(val);
})
.def("create_rsqrt",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::RsqrtOp>(val);
})
.def("create_fabs",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::AbsFOp>(val);
})
.def("create_iabs",
[](TritonOpBuilder &self, Value &val) -> Value {
return self.create<math::AbsIOp>(val);
})
.def("create_reduce",
[](TritonOpBuilder &self, std::vector<Value> operands, int axis)
-> OpState { return self.create<ReduceOp>(operands, axis); })
.def("create_reduce_ret",
[](TritonOpBuilder &self, py::args args) -> OpState {
llvm::SmallVector<Value> return_values;
for (const auto &arg : args) {
return_values.push_back(py::cast<Value>(arg));
}
return self.create<ReduceReturnOp>(return_values);
})
.def("create_scan",
[](TritonOpBuilder &self, std::vector<Value> operands, int axis,
bool reverse) -> OpState {
return self.create<ScanOp>(operands, axis, reverse);
})
.def("create_scan_ret",
[](TritonOpBuilder &self, py::args args) -> OpState {
llvm::SmallVector<Value> return_values;
for (const auto &arg : args) {
return_values.push_back(py::cast<Value>(arg));
}
return self.create<ScanReturnOp>(return_values);
})
.def("create_ptr_to_int",
[](TritonOpBuilder &self, Value &val, Type &type) -> Value {
return self.create<PtrToIntOp>(type, val);
})
.def("create_int_to_ptr",
[](TritonOpBuilder &self, Value &val, Type &type) -> Value {
return self.create<IntToPtrOp>(type, val);
})
.def("create_select",
[](TritonOpBuilder &self, Value &condition, Value &trueValue,
Value &falseValue) -> Value {
return self.create<arith::SelectOp>(condition, trueValue,
falseValue);
})
.def("create_inline_asm",
[](TritonOpBuilder &self, const std::string &inlineAsm,
const std::string &constraints, const std::vector<Value> &values,
const std::vector<Type> &types, bool isPure,
int pack) -> OpState {
return self.create<ElementwiseInlineAsmOp>(
types, inlineAsm, constraints, isPure, pack, values);
})
.def("create_print",
[](TritonOpBuilder &self, const std::string &prefix, bool hex,
const std::vector<Value> &values,
const std::vector<int32_t> &isSigned) -> void {
auto prefixAttr = StringAttr::get(self.getBuilder().getContext(),
llvm::StringRef(prefix));
self.create<PrintOp>(prefixAttr, hex, values, isSigned);
})
.def("create_assert",
[](TritonOpBuilder &self, Value &condition,
const std::string &message) -> void {
auto messageAttr = StringAttr::get(self.getBuilder().getContext(),
llvm::StringRef(message));
self.create<AssertOp>(condition, messageAttr);
})
.def("create_assume",
[](TritonOpBuilder &self, Value &condition) {
self.create<LLVM::AssumeOp>(condition);
})
.def("create_poison",
[](TritonOpBuilder &self, Type &type) -> Value {
return self.create<ub::PoisonOp>(type);
})
.def("create_histogram",
[](TritonOpBuilder &self, Value operand, int numBins) -> Value {
return self.create<HistogramOp>(
RankedTensorType::get(
{static_cast<int64_t>(numBins)},
IntegerType::get(operand.getContext(), 32)),
operand);
})
.def("create_gather",
[](TritonOpBuilder &self, Value src, Value indices, int axis)
-> Value { return self.create<GatherOp>(src, indices, axis); })
.def("create_barrier",
[](TritonOpBuilder &self) { self.create<mlir::gpu::BarrierOp>(); })
.def("create_make_block_ptr",
[](TritonOpBuilder &self, Value &base, std::vector<Value> &shape,
std::vector<Value> &strides, std::vector<Value> &offsets,
std::vector<int32_t> &tensorShape,
std::vector<int32_t> &order) -> Value {
return self.create<MakeTensorPtrOp>(base, shape, strides, offsets,
tensorShape, order);
})
.def("create_advance",
[](TritonOpBuilder &self, Value &ptr,
std::vector<Value> &offsets) -> Value {
return self.create<AdvanceOp>(ptr.getType(), ptr, offsets);
});
py::class_<PassManager>(m, "pass_manager", py::module_local())
.def(py::init<MLIRContext *>())
.def("enable_debug",
[](PassManager &self) {
auto *context = self.getContext();
bool haveDiagnostics =
::triton::tools::getBoolEnv("MLIR_ENABLE_DIAGNOSTICS");
bool haveDump = ::triton::tools::getBoolEnv("MLIR_ENABLE_DUMP");
std::string funcToDump;
if (!haveDump) {
funcToDump = triton::tools::getStrEnv("MLIR_ENABLE_DUMP");
if (!funcToDump.empty())
haveDump = true;
}
if (haveDiagnostics || haveDump) {
context->disableMultithreading();
}
if (haveDiagnostics) {
context->printOpOnDiagnostic(true);
context->printStackTraceOnDiagnostic(true);
context->getDiagEngine().registerHandler([](Diagnostic &diag) {
llvm::outs() << diag << "\n";
return success();
});
}
if (haveDump) {
auto printingFlags = OpPrintingFlags();
printingFlags.elideLargeElementsAttrs(16);
printingFlags.enableDebugInfo();
auto printAlways = [funcToDump](Pass *, Operation *op) -> bool {
if (funcToDump.empty())
return true;
if (auto mod = dyn_cast<mlir::ModuleOp>(op)) {
return mod.lookupSymbol(funcToDump);
}
if (auto func = dyn_cast<triton::FuncOp>(op)) {
return SymbolTable::getSymbolName(func).getValue() ==
funcToDump;
}
return false;
};
self.enableIRPrinting(
printAlways,
printAlways,
true,
false,
true, llvm::dbgs(),
printingFlags);
}
})
.def("run", [](PassManager &self, ModuleOp &mod) {
auto *context = mod.getContext();
if (::triton::tools::getBoolEnv("MLIR_DISABLE_MULTITHREADING"))
context->disableMultithreading();
auto reproducerPath =
triton::tools::getStrEnv("TRITON_REPRODUCER_PATH");
if (!reproducerPath.empty()) {
auto anchorName = self.getOpAnchorName();
auto passes = self.getPasses();
Operation *op = mod.getOperation();
makeReproducer(anchorName, passes, op, reproducerPath);
context->disableMultithreading();
}
if (triton::tools::getBoolEnv("TRITON_ENABLE_LLVM_DEBUG")) {
::llvm::DebugFlag = true;
}
if (auto debugOnly = triton::tools::getStrEnv("TRITON_LLVM_DEBUG_ONLY");
!debugOnly.empty()) {
llvm::SmallVector<StringRef, 3> split;
llvm::SmallVector<std::string, 3> storage;
llvm::SmallVector<const char *, 3> debugTypes;
StringRef(debugOnly.c_str()).split(split, ',');
llvm::transform(split, std::back_inserter(debugTypes),
[&storage](StringRef str) {
storage.push_back(str.str());
return storage.back().c_str();
});
::llvm::DebugFlag = true;
using namespace llvm;
setCurrentDebugTypes(debugTypes.data(), debugTypes.size());
}
bool haveTiming = ::triton::tools::getBoolEnv("MLIR_ENABLE_TIMING");
if (haveTiming) {
self.enableTiming();
}
if (failed(self.run(mod.getOperation())))
throw std::runtime_error("PassManager::run failed");
},
py::call_guard<py::gil_scoped_release>());
}
void init_triton_env_vars(py::module &m) {
m.def("get_cache_invalidating_env_vars",
[]() -> std::map<std::string, std::string> {
std::map<std::string, std::string> ret;
for (const auto &envVar : CACHE_INVALIDATING_ENV_VARS) {
auto strVal = triton::tools::getStrEnv(envVar);
if (strVal.empty())
continue;
auto boolV = triton::tools::isEnvValueBool(strVal);
if (boolV.has_value())
ret[envVar] = boolV.value() ? "true" : "false";
else
ret[envVar] = strVal;
}
return ret;
});
}