* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2022 OpenAI
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#pragma once
#include <memory>
#include <string>
#include <utility>
#include "mlir/IR/Builders.h"
#include "triton/Tools/Sys/GetEnv.hpp"
using namespace mlir;
class TritonOpBuilder {
public:
TritonOpBuilder(MLIRContext *context, const std::string &compile_mode = "simd") {
builder = std::make_unique<OpBuilder>(context);
lastLoc = std::make_unique<Location>(builder->getUnknownLoc());
this->compile_mode = compile_mode;
}
OpBuilder &getBuilder() { return *builder; }
bool isLineInfoEnabled() { return lineInfoEnabled; }
bool isSimtMode() const { return compile_mode == "simt"; }
void setLastLoc(Location loc) {
if (lineInfoEnabled)
lastLoc = std::make_unique<Location>(loc);
}
void setLastLoc(const std::string &fileName, int line, int column) {
auto context = builder->getContext();
setLastLoc(FileLineColLoc::get(context, fileName, line, column));
}
Location getLastLoc() {
assert(lastLoc);
return *lastLoc;
}
void setInsertionPointToStart(Block &block) {
if (!block.empty())
setLastLoc(block.begin()->getLoc());
else
setLastLoc(builder->getUnknownLoc());
builder->setInsertionPointToStart(&block);
}
void setInsertionPointToEnd(Block &block) {
if (!block.empty())
setLastLoc(block.back().getLoc());
else
setLastLoc(builder->getUnknownLoc());
builder->setInsertionPointToEnd(&block);
}
void setInsertionPointAfter(Operation &op) {
setLastLoc(op.getLoc());
builder->setInsertionPointAfter(&op);
}
void restoreInsertionPoint(OpBuilder::InsertPoint pt) {
if (pt.isSet() && pt.getPoint() != pt.getBlock()->end())
setLastLoc(pt.getPoint()->getLoc());
else
setLastLoc(builder->getUnknownLoc());
builder->restoreInsertionPoint(pt);
}
template <typename OpTy, typename... Args> OpTy create(Args &&...args) {
auto loc = getLastLoc();
return builder->create<OpTy>(loc, std::forward<Args>(args)...);
}
template <typename OpTy, typename... Args>
std::enable_if_t<OpTy::template hasTrait<OpTrait::OneResult>(), Value>
createOrFold(Args &&...args) {
auto loc = getLastLoc();
return builder->createOrFold<OpTy>(loc, std::forward<Args>(args)...);
}
template <typename OpTy, typename... Args>
std::enable_if_t<OpTy::template hasTrait<OpTrait::ZeroResults>(), OpTy>
createOrFold(Args &&...args) {
auto loc = getLastLoc();
return builder->createOrFold<OpTy>(loc, std::forward<Args>(args)...);
}
private:
std::unique_ptr<OpBuilder> builder;
std::unique_ptr<Location> lastLoc;
bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO");
std::string compile_mode;
};