/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* Copyright (c) Microsoft Corporation.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "ascend/include/TritonToLinalg/TritonOpConverter.h"
#include "ascend/include/TritonToLinalg/TritonToLinalgPass.h"
#include "ascend/include/TritonToLinalg/BlockPtrAnalysis.h"
#include "ascend/include/TritonToLinalg/MaskAnalysis.h"
#include "ascend/include/Utils/Utils.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/raw_ostream.h"
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <string>
#include <utility>
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Interfaces/CallInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/ValueRange.h"
#include "bishengir/Dialect/Annotation/IR/Annotation.h"
#include "bishengir/Dialect/HFusion/IR/HFusion.h"
#include "bishengir/Dialect/HIVM/IR/HIVM.h"
namespace TTOpConverters {
using namespace mlir;
using namespace triton;
// Names of libdevice math helpers, spelled `__hmf_<op>_<type>`. When an
// operation has both an fp32 and an fp16 variant, the fp32 name comes first.
// NOTE(review): the consumer of this table is outside this chunk; confirm how
// it is used before reordering or editing entries.
static const llvm::SmallVector<llvm::StringRef> libdeviceOps = {
    // Arithmetic and bit-level helpers.
    "__hmf_pow_fp32",
    "__hmf_div_rz_fp32",
    "__hmf_fmod_fp32",
    "__hmf_float_as_int_fp32",
    // Rounding and sign manipulation.
    "__hmf_trunc_fp32",
    "__hmf_trunc_fp16",
    "__hmf_nearbyint_fp32",
    "__hmf_signbit_fp32",
    "__hmf_signbit_fp16",
    "__hmf_copysign_fp32",
    // Logarithmic, trigonometric and hyperbolic functions.
    "__hmf_log10_fp32",
    "__hmf_tanh_fp32",
    "__hmf_asin_fp32",
    "__hmf_asin_fp16",
    "__hmf_acos_fp32",
    "__hmf_acos_fp16",
    "__hmf_atan2_fp32",
    "__hmf_atan2_fp16",
    "__hmf_sinh_fp32",
    "__hmf_sinh_fp16",
    "__hmf_cosh_fp32",
    "__hmf_cosh_fp16",
    "__hmf_asinh_fp32",
    "__hmf_asinh_fp16",
    "__hmf_acosh_fp32",
    "__hmf_acosh_fp16",
    "__hmf_atanh_fp32",
    "__hmf_atanh_fp16",
    // Miscellaneous special functions.
    "__hmf_expm1_fp32",
    "__hmf_expm1_fp16",
    "__hmf_nextafter_fp32",
    "__hmf_nextafter_fp16",
    "__hmf_hypot_fp32",
    "__hmf_hypot_fp16",
    "__hmf_cyl_bessel_i0_fp32",
    "__hmf_cyl_bessel_i0_fp16",
    "__hmf_erfinv_fp32",
    "__hmf_lgamma_fp32",
};
/**
 * Reads a boolean flag from the process environment.
 *
 * @param envVar       Name of the environment variable to read. A null name
 *                     is treated like an unset variable.
 * @param defaultValue Value returned when the variable is not set.
 * @return defaultValue if the variable is unset. Otherwise the value is
 *         compared case-insensitively: an empty string or one of "0",
 *         "false", "no", "off" yields false, and any other value (for
 *         example "1", "true", "yes", "on") yields true.
 *
 * The value is not trimmed, so " off" (with a leading space) counts as true.
 */
bool getEnvBool(const char* envVar, bool defaultValue)
{
  if (envVar == nullptr) {
    // std::getenv(nullptr) is undefined behavior.
    return defaultValue;
  }
  const char* raw = std::getenv(envVar);
  if (raw == nullptr) {
    return defaultValue;
  }
  std::string value(raw);
  // std::tolower requires an unsigned char value; the int result is narrowed
  // back to char explicitly.
  std::transform(value.begin(), value.end(), value.begin(),
                 [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
  const bool isFalseLiteral = value.empty() || value == "0" || value == "false" ||
                              value == "no" || value == "off";
  return !isFalseLiteral;
}
/// Returns a symbol name that does not yet exist in `moduleOp`.
///
/// `funcNameBase` itself is returned when it is free. Otherwise the candidates
/// `funcNameBase_0`, `funcNameBase_1`, ... are tried in order and the first
/// unused one is returned.
static llvm::SmallString<kFuncNameCap> generateUniqueFuncName(
    ModuleOp moduleOp, llvm::StringRef funcNameBase)
{
  llvm::SmallString<kFuncNameCap> candidate(funcNameBase);
  for (int suffix = 0; SymbolTable::lookupSymbolIn(moduleOp, candidate) != nullptr; ++suffix) {
    candidate = funcNameBase;
    candidate += "_";
    candidate += std::to_string(suffix);
  }
  return candidate;
}
/// Lowers tt.bitcast.
///
/// - Non-pointer results: an arith.bitcast to the original result type.
/// - Pointer i1 -> i8: the converted source is forwarded unchanged.
/// - Other pointer casts: an arith.bitcast of the converted source to a 1-D
///   dynamically sized memref of the destination pointee type.
LogicalResult
BitcastConverter::matchAndRewrite(triton::BitcastOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  Value convertedSrc = adaptor.getSrc();

  auto dstPtrTy = dyn_cast<triton::PointerType>(op.getType());
  if (!dstPtrTy) {
    Value casted =
        rewriter.create<arith::BitcastOp>(loc, op.getType(), convertedSrc);
    rewriter.replaceOp(op, casted);
    return success();
  }

  Type srcPointee =
      cast<triton::PointerType>(op.getSrc().getType()).getPointeeType();
  Type dstPointee = dstPtrTy.getPointeeType();
  const bool forwardsBoolBytes = srcPointee == rewriter.getIntegerType(1) &&
                                 dstPointee == rewriter.getIntegerType(8);
  if (forwardsBoolBytes) {
    LLVM_DEBUG({
      llvm::dbgs()
          << "[BitcastConverter] Special i1->i8 pointer bitcast. Forward "
             "without arith.bitcast. srcConvertedTy="
          << adaptor.getSrc().getType() << "\n";
    });
    rewriter.replaceOp(op, convertedSrc);
    return success();
  }

  auto memrefTy = MemRefType::get({ShapedType::kDynamic}, dstPointee);
  Value casted = rewriter.create<arith::BitcastOp>(loc, memrefTy, convertedSrc);
  rewriter.replaceOp(op, casted);
  return success();
}
/// Lowers tt.trans by permuting the converted source according to the op's
/// `order`, via ConverterUtils::getTransposedValue.
LogicalResult
TransposeConverter::matchAndRewrite(triton::TransOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  auto transposed = ConverterUtils::getTransposedValue(
      adaptor.getSrc(), op.getLoc(), rewriter, op.getOrder());
  rewriter.replaceOp(op, transposed);
  return success();
}
/// Rebuilds scf.yield on the converted operands, so the yielded values carry
/// their converted types.
LogicalResult
YieldConverter::matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto convertedOperands = adaptor.getOperands();
  rewriter.replaceOpWithNewOp<scf::YieldOp>(op, convertedOperands);
  return success();
}
/// Delegates the lowering of tt.advance to BlockDataParser::rewriteAdvanceOp,
/// starting from an empty block-data map. The pattern reports success
/// unconditionally.
LogicalResult
AdvanceConverter::matchAndRewrite(triton::AdvanceOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  llvm::SmallDenseMap<Value, BlockData> knownBlockData;
  BlockDataParser::rewriteAdvanceOp(op, rewriter, knownBlockData);
  return success();
}
/// Delegates the lowering of tt.make_tensor_ptr to
/// BlockDataParser::rewriteMakeTensorPtrOp, passing the converted base
/// pointer and an empty block-data map. The pattern reports success
/// unconditionally.
LogicalResult MakeTensorPtrConverter::matchAndRewrite(
    triton::MakeTensorPtrOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  llvm::SmallDenseMap<Value, BlockData> knownBlockData;
  Value convertedBase = adaptor.getBase();
  BlockDataParser::rewriteMakeTensorPtrOp(op, convertedBase, rewriter,
                                          knownBlockData);
  return success();
}
/// Lowers tt.precise_divf to arith.divf.
///
/// The result type is taken verbatim from the original op rather than cast to
/// RankedTensorType, so scalar float results are handled as well as tensors;
/// a dyn_cast to RankedTensorType would produce a null type for scalars. The
/// converted operands from the adaptor are used, as in the other converters
/// of this file.
LogicalResult PreciseDivConverter::matchAndRewrite(
    triton::PreciseDivFOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Type resultType = op.getResult().getType();
  rewriter.replaceOpWithNewOp<arith::DivFOp>(op, resultType, adaptor.getX(),
                                             adaptor.getY());
  return success();
}
/**
 * Rewrites an arith.select on tensors whose condition is a "continuous" mask
 * (as understood by MaskState) into tensor extract/insert slice ops.
 *
 * Scalar selects and i1-element selects are not handled. If MaskState cannot
 * parse the mask, or reports MaskPosition::Unknown, the ops the analysis
 * inserted are erased and the pattern fails.
 *
 * - Head mask: result = insert_slice(extract_slice(trueValue), into
 *   falseValue). The slice bounds come from MaskState.
 * - Other positions ("invert" form, all slice offsets are 0):
 *     1. take falseValue[0 : offset) in the dimension with a non-zero offset,
 *        copy it onto trueValue;
 *     2. take that tensor's [0 : offset + dim) prefix and insert it into
 *        falseValue.
 *   In that dimension this yields false on [0, offset), true on
 *   [offset, offset + dim) and false afterwards.
 * - If the non-zero offset is a runtime value, the four slice ops are moved
 *   into the else-branch of an scf.if. The then-branch yields falseValue
 *   unchanged when offset < 0 or offset >= the static size of that dimension.
 */
LogicalResult SelectCanonicalizer::matchAndRewrite(
arith::SelectOp op, PatternRewriter &rewriter) const {
auto loc = op.getLoc();
auto type = dyn_cast<TensorType>(op.getResult().getType());
if (!type) {
return failure();
}
auto elementType = type.getElementType();
if (elementType.isInteger(1)) {
return failure();
}
auto tensorShape = type.getShape();
auto mask = op.getCondition();
if (!isa<ShapedType>(mask.getType())) {
return failure();
}
// Analyse the mask; parse() may insert helper ops that must be erased on
// every failure path below.
MaskState mstate;
auto isContMask = mstate.parse(mask, loc, rewriter);
if (isContMask.failed()) {
mstate.eraseInsertedOps(op, rewriter);
return rewriter.notifyMatchFailure(
op, "Cannot lower continuous masked selects");
}
MaskPosition maskPos = mstate.getMaskPosition(tensorShape);
LLVM_DEBUG({
llvm::dbgs()
<< "[SelectAnalysis] MaskPosition detected: "
<< (maskPos == MaskPosition::Head ? "Head" :
maskPos == MaskPosition::Tail ? "Tail" :
maskPos == MaskPosition::Middle ? "Middle" : "Unknown") << "\n";
});
if (maskPos == MaskPosition::Unknown) {
mstate.eraseInsertedOps(op, rewriter);
return failure();
}
auto trueTensor = op.getTrueValue();
auto falseTensor = op.getFalseValue();
// Head: copy the masked region of the true value straight into the false value.
if (maskPos == MaskPosition::Head) {
auto extractSliceOp = mstate.getExtractSlice(trueTensor, loc, rewriter);
auto insertSliceOp =
mstate.getInsertSlice(extractSliceOp, falseTensor, loc, rewriter);
LLVM_DEBUG({
llvm::dbgs()
<< " -> Created ExtractSlice: "
<< *extractSliceOp.getOperation() << "\n"
<< " -> Created InsertSlice: "
<< *insertSliceOp.getOperation() << "\n";
});
rewriter.replaceOp(op, insertSliceOp);
return success();
}
// Build zero-based slice bounds:
//   invertFalseDims: size `offset` in the offset dimension, mask dim elsewhere;
//   invertTrueDims:  size `offset + dim` in the offset dimension, mask dim elsewhere.
SmallVector<OpFoldResult> invertOffsets;
SmallVector<OpFoldResult> invertFalseDims;
SmallVector<OpFoldResult> invertTrueDims;
OpFoldResult falseDimOp;
OpFoldResult trueDimOp;
// Index of the dimension whose offset is a runtime value (-1 if none).
int valDim = -1;
for (int i = 0; i< mstate.getRank(); ++i) {
const auto &offVal = mstate.offsets[i];
const auto &dimVal = mstate.dims[i];
auto constOffVal = getConstantIntValue(offVal);
invertOffsets.push_back(rewriter.getIndexAttr(0));
if (constOffVal.has_value() && constOffVal.value() == 0) {
invertFalseDims.push_back(dimVal);
invertTrueDims.push_back(dimVal);
} else {
// NOTE(review): valDim is only set for *dynamic* offsets, so this assert
// does not catch two non-zero constant offsets (nor a dynamic offset that
// follows a constant one). It is also compiled out under NDEBUG.
assert(valDim == -1 && "The offset in only one dimension can be not zero.");
if (!constOffVal.has_value()) {
valDim = i;
falseDimOp = offVal;
}
invertFalseDims.push_back(offVal);
trueDimOp = addOpFoldResult(offVal, dimVal, loc, rewriter);
invertTrueDims.push_back(trueDimOp);
}
}
// false[0:offset) -> onto true; then that tensor's [0:offset+dim) -> into false.
auto falseExtractSliceOp = mstate.getExtractSlice(falseTensor, loc, rewriter,
invertOffsets, invertFalseDims);
auto trueInsertSliceOp = mstate.getInsertSlice(falseExtractSliceOp, trueTensor, loc, rewriter,
invertOffsets, invertFalseDims);
auto extractSliceOp = mstate.getExtractSlice(trueInsertSliceOp, loc, rewriter,
invertOffsets, invertTrueDims);
auto insertSliceOp = mstate.getInsertSlice(extractSliceOp, falseTensor, loc, rewriter,
invertOffsets, invertTrueDims);
if (valDim != -1) {
// Runtime offset: guard the slices with an scf.if so that an out-of-range
// offset (< 0 or >= dimension size) yields the false value unchanged.
rewriter.setInsertionPointAfter(trueInsertSliceOp);
assert(isa<Value>(falseDimOp) && "Expected to be a runtime Value for dynamic dimension check.");
Value zeroIndex = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value isNegative = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::slt,
cast<Value>(falseDimOp), zeroIndex);
Value sizeIndex = rewriter.create<arith::ConstantIndexOp>(loc, tensorShape[valDim]);
Value isOutOfRange = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge,
cast<Value>(falseDimOp), sizeIndex);
auto orOp = rewriter.create<arith::OrIOp>(loc, isNegative, isOutOfRange);
auto ifOp = rewriter.create<scf::IfOp>(loc, TypeRange{type}, orOp.getResult(), true, true);
Block *thenBlock = &ifOp.getThenRegion().front();
rewriter.setInsertionPointToStart(thenBlock);
rewriter.create<scf::YieldOp>(loc, ValueRange{falseTensor});
Block *elseBlock = &ifOp.getElseRegion().front();
rewriter.setInsertionPointToStart(elseBlock);
// Move the four slice ops, in order, into the else-branch.
falseExtractSliceOp->moveBefore(elseBlock, elseBlock->begin());
trueInsertSliceOp->moveAfter(falseExtractSliceOp);
extractSliceOp->moveAfter(trueInsertSliceOp);
insertSliceOp->moveAfter(extractSliceOp);
rewriter.setInsertionPointAfter(insertSliceOp);
rewriter.create<scf::YieldOp>(loc, ValueRange{insertSliceOp.getResult()});
rewriter.replaceOp(op, ifOp);
} else {
rewriter.replaceOp(op, insertSliceOp);
}
LLVM_DEBUG({
llvm::dbgs()
<< " -> [invert] Created false tensor extractSlice: "
<< *falseExtractSliceOp.getOperation() << "\n"
<< " -> [invert] Created true tensor insertSlice: "
<< *trueInsertSliceOp.getOperation() << "\n"
<< " -> [invert] Created ExtractSlice: "
<< *extractSliceOp.getOperation() << "\n"
<< " -> [invert] Created InsertSlice: "
<< *insertSliceOp.getOperation() << "\n";
});
return success();
}
/**
 * Moves a tt.bitcast on pointers towards the pointer's origin. The pattern
 * stops (fails) once the bitcast source has no defining op, e.g. when it is
 * applied directly to a function argument.
 *
 * Handled producers of the bitcast source:
 *   - tt.addptr:  bitcast(addptr(p, off)) -> addptr(bitcast(p), off)
 *   - tt.splat:   bitcast(splat(p))       -> splat(bitcast(p))
 *   - tt.bitcast: bitcast(bitcast(p))     -> bitcast(p)
 * Pointee bit widths must match, with i1 counted as 8 bits. Otherwise an
 * error is emitted and the pattern fails. The original producer is erased
 * once it has no remaining uses.
 */
LogicalResult
BitcastCanonicalizer::matchAndRewrite(triton::BitcastOp bitcastOp,
PatternRewriter &rewriter) const {
Value castSrc = bitcastOp.getSrc();
Value castRes = bitcastOp.getResult();
Type castSrcTy = castSrc.getType();
// Look through tensors of pointers to the scalar pointer element type.
Type castSrcPtrTy = isa<ShapedType>(castSrcTy)
? cast<ShapedType>(castSrcTy).getElementType()
: castSrcTy;
if (!isa<triton::PointerType>(castSrcPtrTy))
return failure();
auto origBitwidth = getPointeeBitWidth(castSrc.getType());
auto castBitwidth = getPointeeBitWidth(castRes.getType());
// i1 pointees are compared as 8-bit values.
if (origBitwidth == 1)
origBitwidth = 8;
if (castBitwidth == 1)
castBitwidth = 8;
if (origBitwidth != castBitwidth) {
bitcastOp.emitError() << "Casting pointers with unmatched bitwidth!\n";
return failure();
}
Operation *beforeCastOp = castSrc.getDefiningOp();
if (beforeCastOp == nullptr) {
return failure();
}
// Re-create the producer on top of a bitcast of its own pointer input.
auto newRes =
TypeSwitch<Operation *, FailureOr<Operation *>>(beforeCastOp)
.Case<triton::AddPtrOp>([&](triton::AddPtrOp addptrOp) {
auto newCastOp = rewriter.create<triton::BitcastOp>(
bitcastOp.getLoc(), castRes.getType(), addptrOp.getPtr());
return rewriter.create<triton::AddPtrOp>(
bitcastOp.getLoc(), castRes.getType(), newCastOp.getResult(),
addptrOp.getOffset());
})
.Case<triton::SplatOp>([&](triton::SplatOp splatOp) {
// Bitcast the splat's source: its element type, or a tensor with the
// source's shape when the splat source is itself a ranked tensor.
Type newCastSrcTy =
cast<RankedTensorType>(castRes.getType()).getElementType();
Value splatSrc = splatOp.getSrc();
Type splatSrcTy = splatSrc.getType();
if (auto splatSrcTensorTy = dyn_cast<RankedTensorType>(splatSrcTy))
newCastSrcTy =
splatSrcTensorTy.cloneWith(std::nullopt, newCastSrcTy);
auto newCastOp = rewriter.create<triton::BitcastOp>(
bitcastOp.getLoc(), newCastSrcTy, splatSrc);
return rewriter.create<triton::SplatOp>(
bitcastOp.getLoc(), castRes.getType(), newCastOp);
})
.Case<triton::BitcastOp>([&](triton::BitcastOp prevCastOp) {
// Collapse a chain of two bitcasts into one.
return rewriter.create<triton::BitcastOp>(
bitcastOp.getLoc(), castRes.getType(), prevCastOp.getSrc());
})
.Default([&](Operation *op) {
return rewriter.notifyMatchFailure(bitcastOp,
"Unknown bitcast pattern");
});
if (succeeded(newRes)) {
rewriter.replaceOp(bitcastOp, newRes.value());
if (beforeCastOp->use_empty()) {
rewriter.eraseOp(beforeCastOp);
}
return success();
}
return failure();
}
/**
 * Canonicalizes tt.fp_to_fp into arith float casts.
 *
 * Rounding:
 * - Only the default (unset) or RTNE rounding mode is handled; any other mode
 *   makes the pattern fail.
 *
 * Lowering by element type:
 * - Narrower destination: arith.truncf.
 * - Wider destination: arith.extf.
 * - Identical element types: folded to the input.
 * - Same bit width but a different format (e.g. f16 <-> bf16): widened to
 *   f32 with arith.extf, then narrowed with arith.truncf. Replacing the op
 *   with the input directly would produce a value of the wrong type. Widths
 *   of 32 bits or more have no f32 intermediate and make the pattern fail.
 *
 * Every created cast is tagged with an hfusion RINT `round_mode` attribute.
 * Scalar operands and shaped (tensor) operands are both supported.
 */
LogicalResult FpToFpCanonicalizer::matchAndRewrite(
    triton::FpToFpOp op, PatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  Value input = op.getSrc();
  Type resultType = op.getResult().getType();
  auto roundingMode = op.getRounding();
  if (roundingMode.has_value() && roundingMode.value() != triton::RoundingMode::RTNE) {
    return failure();
  }
  // Element type of a shaped type, or the type itself for scalars.
  auto elementTypeOf = [](Type type) -> Type {
    if (auto shaped = dyn_cast<ShapedType>(type))
      return shaped.getElementType();
    return type;
  };
  Type srcElemType = elementTypeOf(input.getType());
  Type dstElemType = elementTypeOf(resultType);
  if (!isa<FloatType>(srcElemType) || !isa<FloatType>(dstElemType)) {
    return op.emitError("FpToFp expects floating point types");
  }
  if (srcElemType == dstElemType) {
    rewriter.replaceOp(op, input);
    return success();
  }
  unsigned srcBitwidth = srcElemType.getIntOrFloatBitWidth();
  unsigned dstBitwidth = dstElemType.getIntOrFloatBitWidth();
  auto roundModeAttr = hfusion::RoundModeAttr::get(
      rewriter.getContext(), hfusion::RoundMode::RINT);
  if (srcBitwidth > dstBitwidth) {
    auto truncOp = rewriter.create<arith::TruncFOp>(loc, resultType, input);
    truncOp->setAttr("round_mode", roundModeAttr);
    rewriter.replaceOp(op, truncOp.getResult());
    return success();
  }
  if (srcBitwidth < dstBitwidth) {
    auto extOp = rewriter.create<arith::ExtFOp>(loc, resultType, input);
    extOp->setAttr("round_mode", roundModeAttr);
    rewriter.replaceOp(op, extOp.getResult());
    return success();
  }
  // Same width, different float format: there is no direct arith cast, so go
  // through f32, which represents every sub-32-bit float format exactly.
  if (srcBitwidth >= 32) {
    return rewriter.notifyMatchFailure(
        op, "same-width conversion between different float formats is only "
            "supported below 32 bits");
  }
  Type f32Type = rewriter.getF32Type();
  Type wideType = f32Type;
  if (auto shapedResult = dyn_cast<ShapedType>(resultType))
    wideType = shapedResult.clone(f32Type);
  auto extOp = rewriter.create<arith::ExtFOp>(loc, wideType, input);
  extOp->setAttr("round_mode", roundModeAttr);
  auto truncOp =
      rewriter.create<arith::TruncFOp>(loc, resultType, extOp.getResult());
  truncOp->setAttr("round_mode", roundModeAttr);
  rewriter.replaceOp(op, truncOp.getResult());
  return success();
}
void rewriteUserWithNewOrder(mlir::OpOperand *use, PatternRewriter &rewriter, llvm::SmallVector<int64_t, 8> &blkShapeI64,
mlir::Location &loc, llvm::ArrayRef<int32_t> &order, size_t &orderSize)
{
Operation *user = use->getOwner();
rewriter.setInsertionPointAfter(user);
if (auto loadOp = dyn_cast<triton::LoadOp>(user)) {
auto loadResTy = loadOp.getResult().getType();
auto loadResShapedTy = cast<ShapedType>(loadResTy);
auto newLoadTy = loadResShapedTy.cloneWith(
blkShapeI64, loadResShapedTy.getElementType());
auto newLoadOp = rewriter.create<triton::LoadOp>(
loc, newLoadTy, loadOp->getOperands(), loadOp->getAttrs());
newLoadOp->setAttr(ConverterUtils::GeneratedByMakeTensorPtrTAG, UnitAttr::get(rewriter.getContext()));
rewriter.replaceOp(loadOp, newLoadOp);
SmallVector<int32_t, 8> permuteOrder;
for (auto [i, v] : llvm::enumerate(order)) {
permuteOrder.push_back(orderSize - 1 - order[i]);
}
auto permuteOp = rewriter.create<triton::TransOp>(
loc, newLoadOp.getResult(),
DenseI32ArrayAttr::get(loadOp.getContext(), permuteOrder));
newLoadOp.getResult().replaceAllUsesExcept(permuteOp.getResult(), permuteOp);
} else if (auto storeOp = dyn_cast<triton::StoreOp>(user)) {
SmallVector<int32_t, 8> permuteOrder;
for (auto [i, v] : llvm::enumerate(order)) {
permuteOrder.push_back(order[orderSize - 1 - i]);
}
auto permuteOp = rewriter.create<triton::TransOp>(
loc, storeOp.getValue(),
DenseI32ArrayAttr::get(storeOp.getContext(), permuteOrder));
storeOp.getValue().replaceAllUsesExcept(permuteOp.getResult(), permuteOp);
auto newStoreOp = rewriter.create<triton::StoreOp>(
loc, storeOp.getPtr(), storeOp.getValue(), storeOp.getMask(),
storeOp.getBoundaryCheck(), storeOp.getCache(), storeOp.getEvict());
rewriter.replaceOp(storeOp, newStoreOp);
} else if (auto advanceOp = dyn_cast<triton::AdvanceOp>(user)) {
auto advanceResPtrTy =
cast<triton::PointerType>(advanceOp.getResult().getType());
auto advanceResShapedTy =
cast<ShapedType>(advanceResPtrTy.getPointeeType());
auto newAdvanceResShapedTy = advanceResShapedTy.cloneWith(
blkShapeI64, advanceResShapedTy.getElementType());
auto newAdvanceResPtrTy = triton::PointerType::get(
newAdvanceResShapedTy, advanceResPtrTy.getAddressSpace());
auto advanceOffsets = advanceOp.getOffsets();
llvm::SmallVector<Value, 8> newAdvanceOffsets;
for (int i = orderSize - 1; i >= 0; i--) {
newAdvanceOffsets.push_back(advanceOffsets[order[i]]);
}
SmallVector<OpOperand *> resUses;
for (auto &use: advanceOp->getUses())
resUses.push_back(&use);
auto newAdvanceOp = rewriter.create<triton::AdvanceOp>(
loc, newAdvanceResPtrTy, advanceOp.getPtr(), newAdvanceOffsets);
rewriter.replaceOp(advanceOp, newAdvanceOp);
for (auto resUse : resUses)
rewriteUserWithNewOrder(resUse, rewriter, blkShapeI64, loc, order, orderSize);
} else if (auto loopOp = dyn_cast<LoopLikeOpInterface>(user)) {
auto initArg = use->get();
auto iterArg = loopOp.getTiedLoopRegionIterArg(use);
auto resultValue = loopOp.getTiedLoopResult(use);
iterArg.setType(initArg.getType());
resultValue.setType(initArg.getType());
for (auto &argUse : iterArg.getUses())
rewriteUserWithNewOrder(&argUse, rewriter, blkShapeI64, loc, order, orderSize);
for (auto &resUse : resultValue.getUses())
rewriteUserWithNewOrder(&resUse, rewriter, blkShapeI64, loc, order, orderSize);
} else if (isa<scf::YieldOp>(user)) {
return;
} else {
llvm_unreachable("[MakeTensorPtrCanonicalizer] tt.make_tensor_ptr's result is "
"not used by load/store/advance op");
}
}
/**
 * Tags every tt.load reachable from a tt.make_tensor_ptr use with
 * ConverterUtils::GeneratedByMakeTensorPtrTAG.
 *
 * Traversal:
 * - Follows tt.advance results.
 * - Follows loop iter-args and loop results. The tied iter-arg and result
 *   are given the init value's type on the way.
 * - Stores and scf.yield need no marking.
 * - Any other user hits llvm_unreachable.
 */
void markLoadUsers(mlir::OpOperand *use, PatternRewriter &rewriter)
{
  Operation *owner = use->getOwner();

  if (auto load = dyn_cast<triton::LoadOp>(owner)) {
    load->setAttr(ConverterUtils::GeneratedByMakeTensorPtrTAG, UnitAttr::get(rewriter.getContext()));
    return;
  }

  if (isa<triton::StoreOp>(owner))
    return;

  if (auto advance = dyn_cast<triton::AdvanceOp>(owner)) {
    SmallVector<OpOperand *> advancedUses;
    for (auto &advancedUse : advance->getUses())
      advancedUses.push_back(&advancedUse);
    for (OpOperand *advancedUse : advancedUses)
      markLoadUsers(advancedUse, rewriter);
    return;
  }

  if (auto loop = dyn_cast<LoopLikeOpInterface>(owner)) {
    auto initValue = use->get();
    auto regionArg = loop.getTiedLoopRegionIterArg(use);
    auto loopResult = loop.getTiedLoopResult(use);
    regionArg.setType(initValue.getType());
    loopResult.setType(initValue.getType());
    for (auto &argUse : regionArg.getUses())
      markLoadUsers(&argUse, rewriter);
    for (auto &resUse : loopResult.getUses())
      markLoadUsers(&resUse, rewriter);
    return;
  }

  if (isa<scf::YieldOp>(owner))
    return;

  llvm_unreachable("[MakeTensorPtrCanonicalizer] tt.make_tensor_ptr's result is "
                   "not used by load/store/advance op");
}
/**
 * Normalizes a tt.make_tensor_ptr whose `order` is not contiguous into one
 * with the contiguous order [rank-1, ..., 0].
 *
 * Steps:
 * 1. Every reachable tt.load user (through tt.advance and loop iter-args) is
 *    tagged via markLoadUsers. This happens BEFORE the contiguity check.
 * 2. If the order is already contiguous (each entry is exactly one greater
 *    than the next), the pattern then fails.
 * 3. Otherwise shape, strides, offsets and block shape are reordered so that
 *    new dimension i corresponds to original dimension order[rank-1-i].
 * 4. A new make_tensor_ptr is created and every user is rewritten through
 *    rewriteUserWithNewOrder.
 *
 * NOTE(review): markLoadUsers mutates IR (attributes, iter-arg/result types)
 * without notifying the rewriter, and this pattern can still return failure
 * afterwards. Greedy drivers expect a failing pattern to leave the IR
 * untouched; confirm this is acceptable for the driver in use.
 */
LogicalResult
MakeTensorPtrCanonicalizer::matchAndRewrite(triton::MakeTensorPtrOp op,
PatternRewriter &rewriter) const {
auto order = op.getOrder();
auto orderSize = order.size();
// A rank-1 block pointer has a single order entry: nothing to permute.
if (orderSize == 1) {
return rewriter.notifyMatchFailure(
op, "make_tensor_ptr's order has single value.");
}
// Contiguous means order[i] == order[i + 1] + 1 for every adjacent pair.
bool isPermuted = false;
for (auto [first, second] : llvm::zip(order.slice(0, orderSize - 1),
order.slice(1, orderSize - 1))) {
if (first != second + 1) {
isPermuted = true;
break;
}
}
auto loc = op.getLoc();
auto base = op.getBase();
auto shape = op.getShape();
auto strides = op.getStrides();
auto offsets = op.getOffsets();
auto result = op.getResult();
// Snapshot the uses: they are both marked here and rewritten after replaceOp.
SmallVector<OpOperand *> opUses;
for (auto &use: result.getUses())
opUses.push_back(&use);
for (auto use : opUses)
markLoadUsers(use, rewriter);
if (!isPermuted) {
return rewriter.notifyMatchFailure(
op, "make_tensor_ptr's order is contiguous.");
}
// New block shape: dimension i takes the size of original dimension order[rank-1-i].
llvm::SmallVector<int32_t, 8> blkShapeI32;
llvm::SmallVector<int64_t, 8> blkShapeI64;
auto resPtrType = cast<triton::PointerType>(result.getType());
if (auto resShapedTy = dyn_cast<ShapedType>(resPtrType.getPointeeType())) {
auto resBlkShape = resShapedTy.getShape();
for (auto [i, v] : llvm::enumerate(resBlkShape)) {
auto reverseI = orderSize - 1 - i;
blkShapeI32.push_back(resBlkShape[order[reverseI]]);
blkShapeI64.push_back(resBlkShape[order[reverseI]]);
}
}
// Reorder shape/strides/offsets the same way.
llvm::SmallVector<Value, 8> newShape;
llvm::SmallVector<Value, 8> newStrides;
llvm::SmallVector<Value, 8> newOffsets;
for (int i = orderSize - 1; i >= 0; i--) {
newShape.push_back(shape[order[i]]);
newStrides.push_back(strides[order[i]]);
newOffsets.push_back(offsets[order[i]]);
}
// Contiguous order: [rank-1, ..., 1, 0].
llvm::SmallVector<int, 8> contiguousOrder;
for (int i = orderSize - 1; i >= 0; i--)
contiguousOrder.push_back(i);
rewriter.setInsertionPoint(op);
auto newMakeTensorPtrOp = rewriter.create<triton::MakeTensorPtrOp>(
loc, base, ValueRange(newShape), ValueRange(newStrides),
ValueRange(newOffsets), blkShapeI32, contiguousOrder);
rewriter.replaceOp(op, newMakeTensorPtrOp);
for (auto use : opUses)
rewriteUserWithNewOrder(use, rewriter, blkShapeI64, loc, order, orderSize);
return success();
}
/**
 * Folds a tt.reduce whose first source has exactly one element (every
 * dimension is 1).
 *
 * Reduced value:
 * - Rank-1 source: read with tensor.extract at index 0.
 * - Higher-rank source: collapsed to the result shape with
 *   tensor.collapse_shape.
 *
 * Index result (reduce with two sources): the constant i32 0, either a
 * scalar for a rank-1 source or a zero-filled tensor otherwise.
 *
 * The op is replaced through the rewriter, so it is erased and the driver is
 * notified. Only redirecting uses would leave a dead reduce behind and hide
 * the change from the driver.
 */
LogicalResult ReduceSingleCanonicalizer::matchAndRewrite(triton::ReduceOp reduceOp, PatternRewriter &rewriter) const
{
  assert(reduceOp.getSrcs().size() <= 2 && "Only reduce or reduce with index are supported");
  auto src = reduceOp.getSrcs()[0];
  auto srcType = cast<RankedTensorType>(src.getType());
  auto srcShape = srcType.getShape();
  if (llvm::any_of(srcShape, [](auto s) { return s != 1; }))
    return rewriter.notifyMatchFailure(reduceOp, "reduce's srcs are not all with single element");
  auto loc = reduceOp->getLoc();
  auto res = reduceOp.getResult()[0];

  Value extracted;
  if (srcType.getRank() == 1) {
    auto zero = rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    extracted = rewriter.create<tensor::ExtractOp>(loc, src, zero.getResult()).getResult();
  } else {
    auto resShape = cast<RankedTensorType>(res.getType()).getShape();
    auto collapseReassociationIndicesOptional = getReassociationIndicesForCollapse(srcShape, resShape);
    if (!collapseReassociationIndicesOptional.has_value()) {
      return rewriter.notifyMatchFailure(reduceOp, "Failure with getReassociationIndicesForCollapse call");
    }
    extracted = rewriter.create<tensor::CollapseShapeOp>(
        loc, src, collapseReassociationIndicesOptional.value()).getResult();
  }

  SmallVector<Value, 2> replacements{extracted};
  if (reduceOp.getSrcs().size() == 2) {
    // The single element sits at index 0.
    auto resIdx = reduceOp.getResult()[1];
    Value zeroI32 = rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
    if (srcType.getRank() == 1) {
      replacements.push_back(zeroI32);
    } else {
      auto resIdxShape = cast<RankedTensorType>(resIdx.getType()).getShape();
      auto initTensor = rewriter.create<tensor::EmptyOp>(loc, resIdxShape, rewriter.getI32Type());
      auto fillOp = rewriter.create<linalg::FillOp>(loc, ValueRange{zeroI32}, ValueRange{initTensor});
      replacements.push_back(fillOp.getResult(0));
    }
  }
  rewriter.replaceOp(reduceOp, replacements);
  return success();
}
/**
 * Lowers a splat dense-tensor arith.constant to linalg.fill over tensor.empty.
 *
 * The pattern fails, instead of asserting in cast<> or getSplatValue, in
 * these cases:
 * - the constant is not a DenseElementsAttr;
 * - the attribute is not a splat;
 * - the result is not a ranked tensor;
 * - the splat element cannot be materialized as a scalar arith.constant.
 */
LogicalResult DenseConstantConverter::matchAndRewrite(
    arith::ConstantOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto denseAttr = dyn_cast<DenseElementsAttr>(op.getValue());
  if (!denseAttr || !denseAttr.isSplat()) {
    return rewriter.notifyMatchFailure(
        op, "expected a splat dense elements constant");
  }
  auto resultType = dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!resultType) {
    return rewriter.notifyMatchFailure(op, "expected a ranked tensor result");
  }
  auto loc = op.getLoc();
  auto constSplatOp = arith::ConstantOp::materialize(
      rewriter, denseAttr.getSplatValue<Attribute>(),
      denseAttr.getElementType(), loc);
  if (!constSplatOp) {
    return rewriter.notifyMatchFailure(
        op, "cannot materialize the splat element as a scalar constant");
  }
  auto emptyOp = rewriter.create<tensor::EmptyOp>(
      loc, resultType.getShape(), denseAttr.getElementType());
  rewriter.replaceOpWithNewOp<linalg::FillOp>(op, ValueRange{constSplatOp},
                                              ValueRange{emptyOp});
  return success();
}
/**
 * Lowers tt.make_range (1-D i32) into linalg.
 *
 * A 1-D parallel linalg.generic writes linalg.index 0, cast to i32, into each
 * element, which gives [0, N). The generic is tagged with the attributes
 * tt.from_make_range, tt.make_range_offset (0) and tt.make_range_size (N).
 * A non-zero start is then added as a linalg.fill splat through arith.addi.
 */
LogicalResult
MakeRangeConverter::matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  auto ctx = op.getContext();
  auto rangeType = cast<TensorType>(op.getResult().getType());
  auto rangeShape = rangeType.getShape();
  auto elemType = rangeType.getElementType();
  assert(rangeShape.size() == 1 &&
         isa<IntegerType>(elemType) &&
         elemType.getIntOrFloatBitWidth() == 32 &&
         "make range can only return 1D int32 tensor");

  // Identity indexing map on the single loop dimension.
  AffineMap identityMap =
      AffineMap::get(1, 0, {mlir::getAffineDimExpr(0, ctx)}, ctx);
  SmallVector<AffineMap> indexingMaps{identityMap};
  auto iotaInit = rewriter.create<tensor::EmptyOp>(loc, rangeShape, elemType);

  auto writeIndex = [&](OpBuilder &b, Location, ValueRange) {
    Value position = b.create<linalg::IndexOp>(loc, 0);
    Value asElem = b.create<arith::IndexCastOp>(loc, elemType, position);
    b.create<linalg::YieldOp>(loc, asElem);
  };
  auto iota = rewriter.create<linalg::GenericOp>(
      loc, op->getResultTypes(), ValueRange{}, ValueRange{iotaInit},
      indexingMaps, ConverterUtils::getNParallelLoopsAttrs(1), writeIndex);

  auto indexTy = mlir::IndexType::get(ctx);
  iota->setAttr("tt.from_make_range", mlir::UnitAttr::get(ctx));
  iota->setAttr("tt.make_range_offset", mlir::IntegerAttr::get(indexTy, 0));
  iota->setAttr("tt.make_range_size",
                mlir::IntegerAttr::get(indexTy, rangeShape[0]));

  int32_t start = op.getStartAttr().getInt();
  if (start != 0) {
    Value startScalar = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getI32IntegerAttr(static_cast<int32_t>(start)));
    auto splatInit = rewriter.create<tensor::EmptyOp>(loc, rangeShape, elemType);
    Value startSplat = rewriter
                           .create<linalg::FillOp>(loc, ValueRange{startScalar},
                                                   ValueRange{splatInit})
                           .getResult(0);
    Value shifted =
        rewriter.create<arith::AddIOp>(loc, iota->getResult(0), startSplat);
    rewriter.replaceOp(op, shifted);
    return success();
  }
  rewriter.replaceOp(op, iota->getResults());
  return success();
}
/**
 * Lowers tt.splat.
 *
 * - Shapes whose dimensions are all 1 (single element): tensor.insert of the
 *   scalar at the all-zero index into a tensor.empty.
 * - Any other shape: linalg.fill of the scalar into a tensor.empty.
 */
LogicalResult
SplatConverter::matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  auto resultTy = op.getType();
  auto dims = resultTy.getShape();
  Value scalar = adaptor.getSrc();
  auto empty = rewriter.create<tensor::EmptyOp>(loc, dims,
                                                resultTy.getElementType());

  const bool singleElement =
      llvm::all_of(dims, [](int64_t extent) { return extent == 1; });
  if (!singleElement) {
    rewriter.replaceOpWithNewOp<linalg::FillOp>(op, ValueRange{scalar},
                                                ValueRange{empty});
    return success();
  }

  Value zeroIdx =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
  SmallVector<Value> indices(dims.size(), zeroIdx);
  rewriter.replaceOpWithNewOp<tensor::InsertOp>(op, scalar, empty, indices);
  return success();
}
/**
 * Lowers tt.unsplat by extracting the element at the origin of the converted
 * source tensor, using one constant 0 index per dimension.
 */
LogicalResult
UnsplatConverter::matchAndRewrite(triton::UnsplatOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  Value source = adaptor.getSrc();
  auto sourceTy = cast<RankedTensorType>(source.getType());
  const size_t rank = sourceTy.getShape().size();

  SmallVector<Value> originIndices;
  for (size_t d = 0; d < rank; ++d) {
    originIndices.push_back(
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0)));
  }

  auto element = rewriter.create<tensor::ExtractOp>(
      loc, sourceTy.getElementType(), source, originIndices);
  rewriter.replaceOp(op, element.getResult());
  return success();
}
/**
 * Lowers tt.reshape to tensor.reshape.
 *
 * The target shape is passed as a constant 1-D i64 tensor holding the static
 * result shape. The converted source from the adaptor is used, as in the
 * other converters of this file, so the new op consumes the type-converted
 * value rather than the pre-conversion operand.
 */
LogicalResult
ReshapeConverter::matchAndRewrite(triton::ReshapeOp op, OpAdaptor adaptor,
                                  ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  Value src = adaptor.getSrc();
  Type dstType = op.getResult().getType();
  Value shape = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getI64TensorAttr(cast<ShapedType>(dstType).getShape()));
  auto reshapeOp =
      rewriter.create<tensor::ReshapeOp>(loc, dstType, src, shape);
  rewriter.replaceOp(op, reshapeOp.getResult());
  return success();
}
/**
 * Lowers tt.expand_dims to tensor.expand_shape.
 *
 * The result has one more dimension than the source: a unit dimension is
 * inserted at `axis`. Grouping of source dimension i in the reassociation:
 * - i == axis: {i, i + 1}, i.e. the new unit dimension followed by source
 *   dimension i.
 * - i < axis: {i}, except for the last source dimension when axis equals
 *   the source rank, which becomes {i, i + 1}.
 * - i > axis: {i + 1}.
 *
 * A rank-0 source yields an empty reassociation. The source rank is computed
 * explicitly rather than through `resultRank - 2` (an unsigned underflow for
 * rank-1 results). The converted source from the adaptor is used.
 */
LogicalResult ExpandDimsConverter::matchAndRewrite(
    triton::ExpandDimsOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  Value src = adaptor.getSrc();
  Type resultType = op.getResult().getType();
  const int64_t resultRank = cast<ShapedType>(resultType).getRank();
  const unsigned srcRank =
      resultRank > 0 ? static_cast<unsigned>(resultRank - 1) : 0u;
  const unsigned axis = op.getAxis();

  SmallVector<ReassociationIndices> reassociation;
  reassociation.reserve(srcRank);
  for (unsigned i = 0; i < srcRank; ++i) {
    if (i == axis) {
      reassociation.push_back(ReassociationIndices{i, i + 1});
    } else if (i < axis) {
      const bool isLastSrcDim = (i + 1 == srcRank);
      reassociation.push_back(isLastSrcDim ? ReassociationIndices{i, i + 1}
                                           : ReassociationIndices{i});
    } else {
      reassociation.push_back(ReassociationIndices{i + 1});
    }
  }

  auto expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
      loc, resultType, src, reassociation);
  rewriter.replaceOp(op, expandShapeOp.getResult());
  return success();
}
/**
 * Lowers tt.clampf as max(min_bound, min(x, max_bound)).
 *
 * - Scalar float bounds are splatted with linalg.fill when x is a ranked
 *   tensor.
 * - PropagateNan::NONE uses arith.minnumf / arith.maxnumf.
 * - PropagateNan::ALL uses the NaN-propagating arith.minimumf /
 *   arith.maximumf.
 * - Any other mode makes the pattern fail.
 */
LogicalResult
ClampFConverter::matchAndRewrite(triton::ClampFOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  Value x = adaptor.getX();
  Value lowerBound = adaptor.getMin();
  Value upperBound = adaptor.getMax();
  auto nanMode = adaptor.getPropagateNan();

  if (auto xTensorTy = dyn_cast<RankedTensorType>(x.getType())) {
    // Broadcast a scalar float bound to x's shape; leave other bounds as-is.
    auto splatBound = [&](Value bound) -> Value {
      if (!isa<FloatType>(bound.getType()))
        return bound;
      auto empty = rewriter.create<tensor::EmptyOp>(
          loc, xTensorTy.getShape(), xTensorTy.getElementType());
      return rewriter
          .create<linalg::FillOp>(loc, ValueRange{bound}, ValueRange{empty})
          .getResult(0);
    };
    lowerBound = splatBound(lowerBound);
    upperBound = splatBound(upperBound);
  }

  Value clamped;
  if (nanMode == PropagateNan::NONE) {
    Value cappedAbove = rewriter.create<arith::MinNumFOp>(loc, x, upperBound);
    clamped = rewriter.create<arith::MaxNumFOp>(loc, lowerBound, cappedAbove);
  } else if (nanMode == PropagateNan::ALL) {
    Value cappedAbove = rewriter.create<arith::MinimumFOp>(loc, x, upperBound);
    clamped = rewriter.create<arith::MaximumFOp>(loc, lowerBound, cappedAbove);
  } else {
    return failure();
  }
  rewriter.replaceOp(op, ValueRange{clamped});
  return success();
}
/**
 * Lowers tt.broadcast in two steps:
 * 1. Collapse the source to its non-broadcast shape (tensor.collapse_shape,
 *    via ConverterUtils::getUnbroadcastDims).
 * 2. Expand it with linalg.broadcast along
 *    ConverterUtils::getBroadcastDims into a tensor.empty of the result
 *    shape.
 * The pattern fails when no collapse reassociation exists.
 */
LogicalResult
BroadcastConverter::matchAndRewrite(triton::BroadcastOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  assert(op->getNumResults() == 1 && "BroadcastOp assumes single result");
  auto loc = op.getLoc();
  Value source = adaptor.getSrc();
  auto srcTy = cast<RankedTensorType>(source.getType());
  auto dstTy = cast<RankedTensorType>(op.getType());

  auto broadcastInit = rewriter.create<tensor::EmptyOp>(
      loc, dstTy.getShape(), dstTy.getElementType());
  SmallVector<int64_t> broadcastDims =
      ConverterUtils::getBroadcastDims(srcTy, dstTy);
  SmallVector<int64_t> unbroadcastShape =
      ConverterUtils::getUnbroadcastDims(srcTy, dstTy);

  auto reassociation =
      getReassociationIndicesForCollapse(srcTy.getShape(), unbroadcastShape);
  if (!reassociation.has_value()) {
    return rewriter.notifyMatchFailure(
        op, "Failure with getReassociationIndicesForCollapse call");
  }

  auto collapsedTy =
      RankedTensorType::get(unbroadcastShape, srcTy.getElementType());
  auto collapsed = rewriter.create<tensor::CollapseShapeOp>(
      loc, collapsedTy, source, *reassociation);
  auto broadcast = rewriter.create<linalg::BroadcastOp>(
      loc, collapsed, broadcastInit,
      rewriter.getDenseI64ArrayAttr(broadcastDims));
  rewriter.replaceOp(op, broadcast.getResults());
  return success();
}
/// Returns true when `redOp` is one of the single arith combiners that the
/// single-op reduce lowering accepts: add/mul, the min/max family, and the
/// bitwise and/or/xor.
bool ReduceConverter::isReductionOpSupported(Operation *redOp) const {
  const bool isAddOrMul =
      isa<arith::AddFOp, arith::AddIOp, arith::MulFOp, arith::MulIOp>(redOp);
  const bool isMinOrMax =
      isa<arith::MaximumFOp, arith::MaxNumFOp, arith::MinimumFOp,
          arith::MinNumFOp, arith::MinSIOp, arith::MinUIOp, arith::MaxSIOp,
          arith::MaxUIOp>(redOp);
  const bool isBitwise = isa<arith::AndIOp, arith::OrIOp, arith::XOrIOp>(redOp);
  return isAddOrMul || isMinOrMax || isBitwise;
}
/// Returns true for single arith combiners (sub, div, rem — float, signed
/// and unsigned variants) that are routed through the multi-op (region
/// cloning) reduce lowering.
bool ReduceConverter::isMultiReductionOpSupported(Operation *redOp)
{
  const bool isSub = isa<arith::SubFOp, arith::SubIOp>(redOp);
  const bool isDiv = isa<arith::DivFOp, arith::DivSIOp, arith::DivUIOp>(redOp);
  const bool isRem = isa<arith::RemFOp, arith::RemSIOp, arith::RemUIOp>(redOp);
  return isSub || isDiv || isRem;
}
/// Clones the body of the tt.reduce combine region into `builder`.
///
/// The region's two block arguments are mapped to `in` and `out`, every
/// non-terminator op is cloned, and the remapped value yielded by the
/// region's tt.reduce.return (operand 0) is returned.
///
/// `opIns` and `opOuts` are currently unused.
Value ReduceConverter::cloneReduceOps(OpBuilder &builder, Value in, Value out,
                                      Value opIns, Value opOuts,
                                      triton::ReduceOp op) const
{
  (void)opIns;
  (void)opOuts;
  Region &combineRegion = op->getRegion(0);
  assert(combineRegion.getBlocks().size() == 1 &&
         "tt.reduce combine region expected to have a single block");
  Block &body = combineRegion.getBlocks().front();
  assert(body.getNumArguments() == 2u &&
         "tt.reduce combine block expected to take (lhs, rhs)");

  IRMapping mapping;
  mapping.map(body.getArgument(0), in);
  mapping.map(body.getArgument(1), out);
  for (Operation &bodyOp : body.without_terminator()) {
    builder.clone(bodyOp, mapping);
  }
  auto yield = cast<triton::ReduceReturnOp>(body.getTerminator());
  return mapping.lookup(yield->getOperand(0));
}
/// Debug-build check that none of `reductionOps` is a tt.call: calls in the
/// combine region are expected to have been inlined earlier. Compiles to a
/// no-op when assertions are disabled.
void ReduceConverter::checkIsNotCallOp(
    const llvm::SmallVector<Operation*>& reductionOps) const
{
  for (Operation *candidate : reductionOps) {
    (void)candidate;
    assert(!isa<triton::CallOp>(candidate) &&
           "tt.call ops expected to be inlined in tt.reduce body in ttir building stage");
  }
}
/// Returns true when the combine region contains exactly one op and that op
/// belongs to the scf dialect.
bool ReduceConverter::isSCFOpReduce(
    const llvm::SmallVector<Operation*>& reductionOps) const
{
  if (reductionOps.size() != 1) {
    return false;
  }
  Operation *onlyOp = reductionOps.front();
  llvm::StringRef dialectName = onlyOp->getDialect()->getNamespace();
  return dialectName == scf::SCFDialect::getDialectNamespace();
}
/// Decides whether a reduce must use the multi-op (region cloning) path:
/// more than one combiner op, a single op from the "multi" set
/// (sub/div/rem), or a single scf-dialect op.
bool ReduceConverter::isMultiOpReduce(
    const llvm::SmallVector<Operation*>& reductionOps) const
{
  this->checkIsNotCallOp(reductionOps);
  if (reductionOps.size() > 1) {
    return true;
  }
  const bool singleMultiOp =
      reductionOps.size() == 1 &&
      this->isMultiReductionOpSupported(reductionOps.front());
  if (singleMultiOp) {
    return true;
  }
  return this->isSCFOpReduce(reductionOps);
}
/// Emits the combiner for one step of the linalg.reduce body.
///
/// When `compileOn91095Flag` is set or the original combine region holds
/// more than one op, the whole region is cloned (cloneReduceOps).
/// Otherwise the single real combiner is emitted via getReductionElement.
Value ReduceConverter::computeReduceResultWithCompileFlag(OpBuilder &opBuilder, Location loc, Value lhs, Value rhs,
    Value source, Value initTensor, triton::ReduceOp op, bool compileOn91095Flag) const
{
  auto originalReductionOps = this->getReductionOps(op);
  auto realReductionOps = this->getRealReductionOps(op);
  const bool emitSingleCombiner =
      !compileOn91095Flag && originalReductionOps.size() <= 1;
  if (!emitSingleCombiner) {
    return this->cloneReduceOps(opBuilder, lhs, rhs, source, initTensor, op);
  }
  assert(realReductionOps.size() == 1);
  auto combiner = realReductionOps.front();
  return this->getReductionElement(lhs, rhs, loc, combiner, opBuilder, false);
}
/// Lowers a single-result tt.reduce to linalg.reduce.
///
/// The init tensor is filled with the reduction's base constant. That comes
/// from getMultiOpReductionBaseConstOp for multi-op reduces, otherwise from
/// getReductionBaseConstOp. A rank-1 source reduces into a 0-d tensor whose
/// element is extracted afterwards. The body is produced by
/// computeReduceResultWithCompileFlag.
///
/// Returns a match failure when the combine region has no real reduction
/// op, or when a single-op reduce uses an unsupported combiner.
LogicalResult ReduceConverter::convertToTargetOp(
    triton::ReduceOp op, typename triton::ReduceOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto source = adaptor.getOperands().front();
  auto sourceType = cast<RankedTensorType>(source.getType());
  auto elemType = sourceType.getElementType();
  auto resType = op.getResult().front().getType();
  auto loc = op.getLoc();
  auto realReductionOps = this->getRealReductionOps(op);
  // Every path below dereferences realReductionOps.front(); calling it on an
  // empty vector is undefined behaviour.
  if (realReductionOps.empty()) {
    return rewriter.notifyMatchFailure(op, "No reduction op found in reduce body");
  }
  bool multiOpReduce = this->isMultiOpReduce(realReductionOps);
  if (!multiOpReduce && !this->isReductionOpSupported(realReductionOps.front())) {
    if (compileOn91095Flag) {
      llvm_unreachable("All reduction cases expected to be covered");
    }
    return rewriter.notifyMatchFailure(
        op, "Only support lowering reduction with single op and limited types of reduction");
  }
  auto rop = realReductionOps.front();
  auto ropLoc = rop->getLoc();
  auto axis = op.getAxis();
  const bool isVectorReduce = sourceType.getRank() == 1;
  auto constantType = elemType;
  auto accBaseConstOp = multiOpReduce ?
      this->getMultiOpReductionBaseConstOp(rewriter, op, ropLoc, constantType) :
      this->getReductionBaseConstOp(rewriter, rop, constantType);

  // Rank-1 sources reduce into a 0-d tensor; otherwise the init tensor has
  // the result shape. Either way it is filled with the base constant.
  Value initTensor;
  if (isVectorReduce) {
    auto holder = rewriter.create<bufferization::AllocTensorOp>(
        loc, RankedTensorType::get({}, constantType), ValueRange{});
    initTensor = rewriter
                     .create<linalg::FillOp>(loc, accBaseConstOp.getResult(),
                                             holder.getResult())
                     .getResult(0);
  } else {
    Value init = rewriter.create<tensor::EmptyOp>(
        loc, cast<RankedTensorType>(resType).getShape(), constantType);
    initTensor =
        rewriter.create<linalg::FillOp>(loc, accBaseConstOp.getResult(), init)
            .getResult(0);
  }

  Value finalResult = rewriter.create<linalg::ReduceOp>(
      loc, ValueRange{source}, ValueRange{initTensor},
      SmallVector<int64_t>{axis},
      [&](OpBuilder &opBuilder, Location bodyLoc, ValueRange inputs) {
        assert(inputs.size() == 2);
        Value result = this->computeReduceResultWithCompileFlag(
            opBuilder, bodyLoc, inputs[0], inputs[1], source, initTensor, op,
            compileOn91095Flag);
        opBuilder.create<linalg::YieldOp>(bodyLoc, result);
      })
      .getResult(0);
  // The 0-d result of a vector reduce is unwrapped to a scalar.
  if (isVectorReduce) {
    finalResult = rewriter.create<tensor::ExtractOp>(loc, constantType, finalResult);
  }
  rewriter.replaceOp(op, finalResult);
  return success();
}
/// Lowers a (possibly multi-result) tt.reduce to linalg.reduce by cloning
/// the tt.reduce body into the linalg.reduce body.
///
/// One tensor.empty output is created per result. When the first result is
/// not a ranked tensor, the reduce is treated as scalar: outputs are 0-d and
/// their elements are extracted afterwards. Reduce-with-index patterns
/// detected by getReduceWithIndexParams are tagged via
/// addReduceWithIndexAttr.
///
/// Returns a match failure, before any IR is created, when
/// getReduceWithIndexParams fails.
LogicalResult ReduceConverter::convertToTargetOpExtended(
    triton::ReduceOp op, typename triton::ReduceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const {
  // Classify first so that a failed match does not leave newly created ops
  // behind.
  auto params = getReduceWithIndexParams(op);
  if (failed(params)) {
    return rewriter.notifyMatchFailure(op, "meaningless reduce operation");
  }

  auto loc = op.getLoc();
  auto elemTypes = op.getElementTypes();
  auto valueResultType = dyn_cast<RankedTensorType>(op.getType(0));
  const bool isScalarReduce = valueResultType == nullptr;

  // NOTE(review): the outputs are tensor.empty, i.e. linalg.reduce starts from
  // an undefined accumulator; presumably downstream passes account for this
  // — TODO confirm.
  const size_t numOutputs =
      std::min<size_t>(op.getResult().size(), elemTypes.size());
  SmallVector<Value> outputs;
  outputs.reserve(numOutputs);
  for (size_t i = 0; i < numOutputs; ++i) {
    SmallVector<int64_t> resultShape;
    if (!isScalarReduce) {
      auto resultTensorType = cast<RankedTensorType>(op.getType(i));
      resultShape.assign(resultTensorType.getShape().begin(),
                         resultTensorType.getShape().end());
    }
    outputs.push_back(
        rewriter.create<tensor::EmptyOp>(loc, resultShape, elemTypes[i]));
  }

  auto linalgOp = rewriter.create<linalg::ReduceOp>(
      loc, adaptor.getOperands(), outputs,
      SmallVector<int64_t>{adaptor.getAxis()},
      [&](OpBuilder &b, Location bodyLoc, ValueRange inputs) {
        auto tritonReduceBlock = op.getBody();
        IRMapping mapping;
        mapping.map(tritonReduceBlock->getArguments(), inputs);
        for (auto &bodyOp : tritonReduceBlock->without_terminator()) {
          b.clone(bodyOp, mapping);
        }
        auto tritonYield = tritonReduceBlock->getTerminator();
        auto results =
            llvm::map_to_vector(tritonYield->getOperands(),
                                [&](Value val) { return mapping.lookup(val); });
        b.create<linalg::YieldOp>(bodyLoc, results);
      });

  if (params->withIndexType != ReduceWithIndexType::None) {
    addReduceWithIndexAttr(*params, rewriter, linalgOp);
  }

  if (!isScalarReduce) {
    rewriter.replaceOp(op, linalgOp);
    return success();
  }

  // Scalar reduce: unwrap each 0-d result tensor.
  const size_t numResults =
      std::min<size_t>(linalgOp.getResults().size(), elemTypes.size());
  SmallVector<Value> reduceResults;
  reduceResults.reserve(numResults);
  for (size_t i = 0; i < numResults; ++i) {
    reduceResults.push_back(rewriter.create<tensor::ExtractOp>(
        loc, elemTypes[i], linalgOp.getResults()[i], ValueRange{}));
  }
  rewriter.replaceOp(op, reduceResults);
  return success();
}
/// Returns true for scan combiners that ScanConverter::convertToTargetOp
/// lowers to an external call (triton_cumsum for add, triton_cumprod for
/// mul). Any other combiner takes the generic loop-nest lowering.
bool ScanConverter::isReductionOpSupported(Operation *reductionOp) const
{
  const bool isAdd = isa<arith::AddFOp, arith::AddIOp>(reductionOp);
  const bool isMul = isa<arith::MulFOp, arith::MulIOp>(reductionOp);
  return isAdd || isMul;
}
/// Lowers a single-input tt.scan.
///
/// * If the first combiner op is an add or a mul, the scan is replaced by a
///   call to a private external function named triton_cumsum /
///   triton_cumprod (renamed to a unique symbol if needed) with signature
///   (src, axis : i32, reverse : i1) -> result.
/// * Otherwise the combine region is inlined into an explicit loop nest over
///   memrefs. The input goes through bufferization.to_buffer, and the first
///   element along the scan axis is copied. Each later element is computed
///   by cloning the combine region with (previous result, current input) as
///   its arguments. The output buffer is converted back via
///   bufferization.to_tensor.
LogicalResult ScanConverter::convertToTargetOp(
    triton::ScanOp op, typename triton::ScanOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto reductionOps = this->getReductionOps(op);
  if (reductionOps.empty()) {
    return rewriter.notifyMatchFailure(op, "No reduction op found in scan body");
  }
  auto rop = reductionOps.front();

  if (this->isReductionOpSupported(rop)) {
    llvm::SmallString<64> funcName;
    if (isa<arith::AddFOp, arith::AddIOp>(rop)) {
      funcName = "triton_cumsum";
    } else if (isa<arith::MulFOp, arith::MulIOp>(rop)) {
      funcName = "triton_cumprod";
    }
    // Declare the external function just before the module's last op.
    auto moduleOp = op->getParentOfType<ModuleOp>();
    rewriter.setInsertionPoint(moduleOp.getBody(),
                               std::prev(moduleOp.getBody()->end()));
    auto loc = op.getLoc();
    auto src = adaptor.getOperands().front();
    auto resTy = op.getResult().front().getType();
    auto libFnType = rewriter.getFunctionType(
        {src.getType(), rewriter.getI32Type(), rewriter.getI1Type()}, {resTy});
    auto funcOp = rewriter.create<func::FuncOp>(loc, funcName.str(), libFnType);
    SymbolTable symTab(moduleOp);
    auto maybeFuncNameAttr = symTab.renameToUnique(funcOp, {&symTab});
    if (failed(maybeFuncNameAttr)) {
      return op->emitError("failed to create a unique func name for ")
             << funcName.str();
    }
    SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private);

    rewriter.setInsertionPoint(op);
    auto scanAxis = op.getAxis();
    auto scanReverse = op.getReverse();
    Value axis = rewriter.create<arith::ConstantIntOp>(loc, scanAxis, 32);
    Value reverseVal = rewriter.create<arith::ConstantIntOp>(loc, scanReverse, 1);
    auto callOp = rewriter.create<func::CallOp>(loc, funcOp.getSymNameAttr(),
                                                TypeRange({resTy}),
                                                ValueRange({src, axis, reverseVal}));
    rewriter.replaceOp(op, callOp);
    return success();
  }

  // Generic path: inline the combine region into an explicit loop nest.
  bool reverse = op.getReverse();
  auto loc = op.getLoc();
  Value scanInput = op.getOperand(0);
  auto srcType = mlir::dyn_cast<RankedTensorType>(scanInput.getType());
  if (!srcType) {
    return rewriter.notifyMatchFailure(op, "Expected RankedTensorType input for associative_scan");
  }
  auto elementType = srcType.getElementType();
  auto shape = srcType.getShape();
  int rank = shape.size();
  int axis = op.getAxis();
  if (axis < 0 || axis >= rank) {
    return rewriter.notifyMatchFailure(op, "Invalid scan axis: " + std::to_string(axis));
  }
  if (op->getNumRegions() < 1 || op->getRegion(0).empty()) {
    return rewriter.notifyMatchFailure(op, "Missing combine region");
  }

  OpBuilder::InsertionGuard guard(rewriter);
  auto memrefType = MemRefType::get(shape, elementType);
  Value inputMemRef = rewriter.create<bufferization::ToBufferOp>(loc, memrefType, scanInput);
  Value outputMemRef = rewriter.create<memref::AllocOp>(loc, memrefType);

  // Builds a full index by inserting `axisIdx` at position `axis` among the
  // indices of the non-scan dimensions.
  auto withAxisIndex = [axis](ArrayRef<Value> baseIdxs, Value axisIdx) {
    llvm::SmallVector<Value> idx(baseIdxs.begin(), baseIdxs.end());
    if (static_cast<size_t>(axis) <= idx.size()) {
      idx.insert(idx.begin() + axis, axisIdx);
    } else {
      idx.push_back(axisIdx);
    }
    return idx;
  };

  // Emits the scan along `axis` for one fixed position of the other dims.
  auto processDimension = [&](ArrayRef<Value> baseIdxsArray) {
    llvm::SmallVector<Value> baseIdxs(baseIdxsArray.begin(), baseIdxsArray.end());
    // Assumes a static extent along the scan axis.
    const int64_t lastPos = shape[axis] - 1;

    // Seed: out[start] = in[start], where start is the first element in scan
    // order (0 forward, last element in reverse).
    Value startInd =
        rewriter.create<arith::ConstantIndexOp>(loc, reverse ? lastPos : 0);
    auto firstIdx = withAxisIndex(baseIdxs, startInd);
    Value firstVal = rewriter.create<memref::LoadOp>(loc, inputMemRef, firstIdx);
    rewriter.create<memref::StoreOp>(loc, firstVal, outputMemRef, firstIdx);

    Value axisSize = rewriter.create<memref::DimOp>(loc, inputMemRef, axis).getResult();
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, axisSize, one);
    auto ifOp = rewriter.create<scf::IfOp>(loc, cmp, false);
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto forOp = rewriter.create<scf::ForOp>(loc, one, axisSize, one);
    rewriter.setInsertionPointToStart(forOp.getBody());

    // k walks 1..n-1 forward, or (n-1)-1 .. 0 in reverse.
    Value k = forOp.getInductionVar();
    if (reverse) {
      Value last = rewriter.create<arith::ConstantIndexOp>(loc, lastPos);
      k = rewriter.create<arith::SubIOp>(loc, last, k);
    }
    auto currIdx = withAxisIndex(baseIdxs, k);

    // The previously produced element is k-1 forward, k+1 in reverse.
    Value prevAxisIdx;
    if (reverse) {
      prevAxisIdx = rewriter.create<arith::AddIOp>(loc, k, one);
    } else {
      prevAxisIdx = rewriter.create<arith::SubIOp>(loc, k, one);
    }
    auto prevIdx = withAxisIndex(baseIdxs, prevAxisIdx);

    Value currentVal = rewriter.create<memref::LoadOp>(loc, inputMemRef, currIdx);
    Value prevResult = rewriter.create<memref::LoadOp>(loc, outputMemRef, prevIdx);

    // out[k] = combine(out[prev], in[k])
    Block &combineBlock = op->getRegion(0).front();
    IRMapping mapping;
    mapping.map(combineBlock.getArgument(0), prevResult);
    mapping.map(combineBlock.getArgument(1), currentVal);
    for (Operation &innerOp : combineBlock.without_terminator()) {
      rewriter.clone(innerOp, mapping);
    }
    Operation *yieldOp = combineBlock.getTerminator();
    Value resultVal = mapping.lookup(yieldOp->getOperand(0));
    rewriter.create<memref::StoreOp>(loc, resultVal, outputMemRef, currIdx);
    rewriter.setInsertionPointAfter(ifOp);
  };

  llvm::SmallVector<int> nonScanDims;
  for (int i = 0; i < rank; ++i) {
    if (i != axis) nonScanDims.push_back(i);
  }
  createSimpleNestedLoops(rewriter, loc, outputMemRef, nonScanDims, processDimension);

  rewriter.setInsertionPointAfter(op);
  mlir::Type resultType = mlir::memref::getTensorTypeFromMemRefType(dyn_cast<mlir::MemRefType>(outputMemRef.getType()));
  Value outputTensor = rewriter.create<bufferization::ToTensorOp>(loc, resultType, outputMemRef, true);
  rewriter.replaceOp(op, outputTensor);
  return success();
}
/// Lowers a multi-input tt.scan by inlining its combine region into an
/// explicit loop nest over memrefs.
///
/// All inputs must be ranked tensors of the same shape. The combine block
/// must take 2*N arguments (N previous results followed by N current inputs)
/// and yield N values. These structural checks run before any IR is
/// created, so a malformed op leaves the IR untouched. Each input gets an
/// output memref.alloc, the first element along the scan axis is copied,
/// and each later element is produced by cloning the combine region. The
/// outputs are converted back with bufferization.to_tensor.
LogicalResult ScanConverter::convertToTargetOpExtended(
    triton::ScanOp op, typename triton::ScanOp::Adaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  bool reverse = op.getReverse();
  auto operands = op->getOperands();
  if (operands.empty()) {
    return rewriter.notifyMatchFailure(op, "No input operands for extended scan");
  }
  llvm::SmallVector<RankedTensorType> inputTensTypes;
  for (auto operand : operands) {
    auto tensorTy = dyn_cast<RankedTensorType>(operand.getType());
    if (!tensorTy) {
      return rewriter.notifyMatchFailure(op, "All inputs must be RankedTensorType");
    }
    inputTensTypes.push_back(tensorTy);
  }
  auto baseShape = inputTensTypes[0].getShape();
  int rank = baseShape.size();
  int axis = op.getAxis();
  if (axis < 0 || axis >= rank) {
    return rewriter.notifyMatchFailure(op, "Invalid scan axis: " + std::to_string(axis));
  }
  for (size_t i = 1; i < inputTensTypes.size(); ++i) {
    if (inputTensTypes[i].getShape() != baseShape) {
      return rewriter.notifyMatchFailure(op, "All inputs must have the same shape");
    }
  }
  const size_t numInputs = inputTensTypes.size();

  // Validate the combine region up front. Previously these checks ran
  // inside the loop-building callback, after IR had been emitted.
  Region &combineRegion = op->getRegion(0);
  if (combineRegion.empty()) {
    return op->emitError("Missing combine region in extended scan");
  }
  Block &combineBlock = combineRegion.front();
  if (combineBlock.getNumArguments() != 2 * numInputs) {
    return op->emitError("Combine region arguments mismatch with input count");
  }
  Operation *yieldOp = combineBlock.getTerminator();
  if (yieldOp->getNumOperands() != numInputs) {
    return op->emitError("Combine region returns mismatch with output count");
  }

  llvm::SmallVector<Value> inputMemRefs;
  llvm::SmallVector<Value> outputMemRefs;
  for (size_t i = 0; i < numInputs; ++i) {
    auto &tensorTy = inputTensTypes[i];
    auto memRefTy = MemRefType::get(tensorTy.getShape(), tensorTy.getElementType());
    inputMemRefs.push_back(rewriter.create<bufferization::ToBufferOp>(loc, memRefTy, operands[i]));
    outputMemRefs.push_back(rewriter.create<memref::AllocOp>(loc, memRefTy));
  }

  // Assumes a static extent along the scan axis.
  const int64_t axisLen = baseShape[axis];

  // Builds a full index by inserting `axisIdx` at position `axis` among the
  // indices of the non-scan dimensions.
  auto withAxisIndex = [axis](ArrayRef<Value> baseIdxs, Value axisIdx) {
    llvm::SmallVector<Value> idx(baseIdxs.begin(), baseIdxs.end());
    if (static_cast<size_t>(axis) <= idx.size()) {
      idx.insert(idx.begin() + axis, axisIdx);
    } else {
      idx.push_back(axisIdx);
    }
    return idx;
  };

  // Emits the scan along `axis` for one fixed position of the other dims.
  auto processDimension = [&](ArrayRef<Value> baseIdxsArray) {
    llvm::SmallVector<Value> baseIdxs(baseIdxsArray.begin(), baseIdxsArray.end());

    // Seed every output with its input's first element in scan order.
    Value startInd =
        rewriter.create<arith::ConstantIndexOp>(loc, reverse ? axisLen - 1 : 0);
    auto firstIdx = withAxisIndex(baseIdxs, startInd);
    for (size_t i = 0; i < numInputs; ++i) {
      Value firstVal = rewriter.create<memref::LoadOp>(loc, inputMemRefs[i], firstIdx);
      rewriter.create<memref::StoreOp>(loc, firstVal, outputMemRefs[i], firstIdx);
    }

    Value axisSize = rewriter.create<arith::ConstantIndexOp>(loc, axisLen);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value cmp = rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sgt, axisSize, one);
    auto ifOp = rewriter.create<scf::IfOp>(loc, cmp, false);
    rewriter.setInsertionPointToStart(ifOp.thenBlock());
    auto forOp = rewriter.create<scf::ForOp>(loc, one, axisSize, one);
    rewriter.setInsertionPointToStart(forOp.getBody());

    // k walks 1..n-1 forward, or (n-1)-1 .. 0 in reverse.
    Value k = forOp.getInductionVar();
    if (reverse) {
      Value axisSizeMinusOne = rewriter.create<arith::SubIOp>(loc, axisSize, one);
      k = rewriter.create<arith::SubIOp>(loc, axisSizeMinusOne, k);
    }
    auto currIdx = withAxisIndex(baseIdxs, k);

    // The previously produced element is k-1 forward, k+1 in reverse.
    Value prevIndex;
    if (reverse) {
      prevIndex = rewriter.create<arith::AddIOp>(loc, k, one);
    } else {
      prevIndex = rewriter.create<arith::SubIOp>(loc, k, one);
    }
    auto prevIdx = withAxisIndex(baseIdxs, prevIndex);

    llvm::SmallVector<Value> currentVals;
    llvm::SmallVector<Value> prevResults;
    for (size_t i = 0; i < numInputs; ++i) {
      currentVals.push_back(rewriter.create<memref::LoadOp>(loc, inputMemRefs[i], currIdx));
      prevResults.push_back(rewriter.create<memref::LoadOp>(loc, outputMemRefs[i], prevIdx));
    }

    // Block args: [prev_0 .. prev_{N-1}, curr_0 .. curr_{N-1}].
    IRMapping mapping;
    for (size_t i = 0; i < numInputs; ++i) {
      mapping.map(combineBlock.getArgument(i), prevResults[i]);
      mapping.map(combineBlock.getArgument(i + numInputs), currentVals[i]);
    }
    for (Operation &innerOp : combineBlock.without_terminator()) {
      rewriter.clone(innerOp, mapping);
    }
    for (size_t i = 0; i < numInputs; ++i) {
      Value resultVal = mapping.lookup(yieldOp->getOperand(i));
      rewriter.create<memref::StoreOp>(loc, resultVal, outputMemRefs[i], currIdx);
    }
    rewriter.setInsertionPointAfter(ifOp);
  };

  llvm::SmallVector<int> nonScanDims;
  for (int i = 0; i < rank; ++i) {
    if (i != axis) nonScanDims.push_back(i);
  }
  createSimpleNestedLoops(rewriter, loc, outputMemRefs[0], nonScanDims, processDimension);

  llvm::SmallVector<Value> outputTensors;
  for (auto outputMemRef : outputMemRefs) {
    mlir::Type resultType = mlir::memref::getTensorTypeFromMemRefType(dyn_cast<mlir::MemRefType>(outputMemRef.getType()));
    outputTensors.push_back(rewriter.create<bufferization::ToTensorOp>(loc, resultType, outputMemRef, true));
  }
  rewriter.replaceOp(op, outputTensors);
  return success();
}
// Lowers a pure tt.extern_elementwise whose symbol contains "__hmf_" into
// per-element func.call ops on a private external declaration of that symbol.
// - Impure ops: a warning is emitted and the pattern fails.
// - Symbols without "__hmf_": the pattern fails.
// - Scalar operands are wrapped into 1-element tensors (tensor.from_elements);
//   a scalar result is computed as a 1-element tensor and extracted at index 0.
// - "libdevice" mode (symbol listed in libdeviceOps AND env var
//   TRITON_ENABLE_LIBDEVICE_SIMT true) emits an explicit scf.for loop nest;
//   otherwise a linalg.map is emitted.
LogicalResult ExternElementwiseClOpConverter::matchAndRewrite(
triton::ExternElementwiseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
auto loc = op.getLoc();
if (!op.getPure()) {
op->emitWarning() << "impure elementwise op!";
return failure();
}
if (op.getSymbol().contains("__hmf_")) {
Type dstTy = op.getResult().getType();
bool isDstScalar = !isa<RankedTensorType>(dstTy);
Type dstElemTy =
isDstScalar ? dstTy : cast<RankedTensorType>(dstTy).getElementType();
// Collect operands as ranked tensors (scalars wrapped into tensor<1xT>)
// together with their element types, which form the callee signature.
SmallVector<Type, 4> srcElemTys;
SmallVector<Value, 4> srcs;
for (auto src : op.getSrcs()) {
if (!isa<RankedTensorType>(src.getType())) {
src = rewriter.create<tensor::FromElementsOp>(
op.getLoc(), RankedTensorType::get({(int64_t)1}, src.getType()),
src);
}
srcs.push_back(src);
srcElemTys.push_back(
cast<RankedTensorType>(src.getType()).getElementType());
}
// The external function operates on single elements: (elem types) -> dst elem.
FunctionType elemFuncType =
FunctionType::get(rewriter.getContext(), srcElemTys, {dstElemTy});
auto mod = SymbolTable::getNearestSymbolTable(op);
auto extFunc = dyn_cast_or_null<SymbolOpInterface>(
SymbolTable::lookupSymbolIn(mod, op.getSymbol()));
bool is_libdevice = llvm::is_contained(libdeviceOps, op.getSymbol()) && getEnvBool("TRITON_ENABLE_LIBDEVICE_SIMT", false);
// Declare the callee once, at the start of the enclosing symbol table, as a
// private readnone func. The hivm AIV core-type attribute is only attached
// when the declaration is created here; an existing declaration is reused
// as-is.
if (!extFunc) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(&mod->getRegion(0).front());
extFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(),
op.getSymbol(), elemFuncType);
extFunc.setPrivate();
extFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(),
UnitAttr::get(rewriter.getContext()));
if (is_libdevice) {
hivm::TFuncCoreType e = hivm::TFuncCoreType::AIV;
extFunc->setAttr(hivm::TFuncCoreTypeAttr::name,
hivm::TFuncCoreTypeAttr::get(extFunc->getContext(), e));
}
}
assert(isa<FunctionOpInterface>(
SymbolTable::lookupSymbolIn(mod, op.getSymbol())));
// Pick the destination tensor: reuse the first operand whose type equals the
// (tensorized) result type, else a fresh tensor.empty.
Value output;
if (isDstScalar) {
dstTy = RankedTensorType::get({(int64_t)1}, dstElemTy);
}
bool found = false;
for (Value v : srcs) {
if (v.getType() == dstTy) {
found = true;
output = v;
break;
}
}
if (!found) {
output = rewriter.create<tensor::EmptyOp>(
op.getLoc(), cast<RankedTensorType>(dstTy).getShape(), dstElemTy);
}
if (is_libdevice) {
// Loop bounds are taken from the first operand: tensor.dim for dynamic
// dims, constants for static dims.
// NOTE(review): indexes srcs[0] unconditionally — assumes the op has at
// least one operand; confirm an extern_elementwise with no srcs cannot
// reach this path.
auto srcType = cast<RankedTensorType>(srcs[0].getType());
SmallVector<Value> dimSizes;
int64_t rank = srcType.getRank();
for (int i = 0; i < rank; ++i) {
if (srcType.isDynamicDim(i)) {
auto dimOp = rewriter.create<tensor::DimOp>(loc, srcs[0], i);
dimSizes.push_back(dimOp);
} else {
auto constOp = rewriter.create<arith::ConstantIndexOp>(loc, srcType.getDimSize(i));
dimSizes.push_back(constOp);
}
}
// Recursively emits one scf.for per dimension, threading the destination
// tensor through iter_args. At full depth, extracts each operand's element,
// calls the external function, and inserts the result into the accumulator.
// NOTE(review): nested loop bodies are built with a plain OpBuilder
// (innerBuilder), not the ConversionPatternRewriter, so those ops are not
// reported to the conversion listener — confirm this is intended.
std::function<Value(OpBuilder&, Location, SmallVector<Value>, Value)> buildLoops = [&](
OpBuilder &b, Location loc, SmallVector<Value> indices, Value acc) -> Value {
int64_t dim = indices.size();
if (dim == rank) {
SmallVector<Value> elemVals;
for (auto src : srcs) {
auto extract = b.create<tensor::ExtractOp>(loc, src, indices);
elemVals.push_back(extract);
}
auto call = b.create<func::CallOp>(loc, op.getSymbol(), dstElemTy, elemVals);
auto insert = b.create<tensor::InsertOp>(loc, call.getResult(0), acc, indices);
return insert;
} else {
Value lower = b.create<arith::ConstantIndexOp>(loc, 0);
Value upper = dimSizes[dim];
Value step = b.create<arith::ConstantIndexOp>(loc, 1);
auto loop = b.create<scf::ForOp>(loc, lower, upper, step, ValueRange{acc});
Block *body = loop.getBody();
OpBuilder innerBuilder = OpBuilder::atBlockBegin(body);
SmallVector<Value> newIndices = indices;
newIndices.push_back(loop.getInductionVar());
Value innerAcc = loop.getRegionIterArgs()[0];
Value updatedAcc = buildLoops(innerBuilder, loc, newIndices, innerAcc);
innerBuilder.create<scf::YieldOp>(loc, updatedAcc);
return loop.getResult(0);
}
};
Value result = buildLoops(rewriter, loc, {}, output);
// Scalar result: extract at all-zero indices (`rank` zeros, i.e. the rank
// of srcs[0]; assumes that matches the tensorized result rank).
if (isDstScalar) {
SmallVector<Value> zeroIndices(rank, rewriter.create<arith::ConstantIndexOp>(loc, 0));
auto extract = rewriter.create<tensor::ExtractOp>(loc, result, zeroIndices);
rewriter.replaceOp(op, extract);
} else {
rewriter.replaceOp(op, result);
}
return success();
}
// Non-libdevice path: linalg.map whose body calls the external function on
// the per-element region arguments.
auto mapOp = rewriter.create<linalg::MapOp>(
loc,
srcs,
output,
[&](OpBuilder &builder, Location loc, ValueRange regionArgs) {
auto elemOp = builder.create<func::CallOp>(loc,
op.getSymbol(),
dstElemTy,
regionArgs);
builder.create<linalg::YieldOp>(loc, elemOp->getResults());
});
// Scalar result: the map produced tensor<1xT>; extract element 0.
if (isDstScalar) {
auto indexType = rewriter.getIndexType();
Value zeroConstant = rewriter.create<arith::ConstantOp>(
loc, indexType, rewriter.getIntegerAttr(indexType, 0));
auto extractOp = rewriter.create<tensor::ExtractOp>(
loc, mapOp.getResults()[0], zeroConstant);
rewriter.replaceOp(op, extractOp);
} else {
rewriter.replaceOp(op, mapOp);
}
return success();
}
return failure();
}
/// Unconditionally erases an unrealized_conversion_cast; its inputs are not
/// forwarded to any users.
/// NOTE(review): this assumes the cast's results are unused (or replaced by
/// other patterns) — confirm with the pass that registers this pattern.
LogicalResult UnrealizedCastConverter::matchAndRewrite(
    UnrealizedConversionCastOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  (void)adaptor;
  Operation *castOp = op.getOperation();
  rewriter.eraseOp(castOp);
  return success();
}
/// Lowers tt.join into two tensor.insert_slice ops on a tensor.empty of the
/// result shape.
///
/// Each input fills the full leading extent and a single element of the
/// trailing result dimension: lhs at trailing offset 0, rhs at trailing
/// offset 1.
LogicalResult
JoinConverter::matchAndRewrite(triton::JoinOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  Value opa = op.getLhs();
  Value opb = op.getRhs();
  auto loc = op.getLoc();
  // tt.join operates on ranked tensors; cast<> asserts instead of
  // dereferencing an unchecked dyn_cast result.
  auto resType = cast<RankedTensorType>(op.getResult().getType());
  auto inputType = cast<RankedTensorType>(opa.getType());
  Value emptyOp = rewriter.create<tensor::EmptyOp>(loc, resType.getShape(),
                                                   resType.getElementType());
  const int64_t rank = resType.getRank();

  // Index-typed attributes, matching offsets/strides (and SplitConverter).
  SmallVector<OpFoldResult> sizes;
  sizes.reserve(rank);
  for (int64_t dimSize : inputType.getShape()) {
    sizes.push_back(rewriter.getIndexAttr(dimSize));
  }
  sizes.push_back(rewriter.getIndexAttr(1));
  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
  strides.back() = rewriter.getIndexAttr(2);
  SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));

  auto insert0 = rewriter.create<tensor::InsertSliceOp>(
      loc, opa, emptyOp, offsets, sizes, strides);
  offsets.back() = rewriter.getIndexAttr(1);
  auto insert1 = rewriter.create<tensor::InsertSliceOp>(
      loc, opb, insert0, offsets, sizes, strides);
  rewriter.replaceOp(op, insert1);
  return success();
}
/// Lowers a 1-D tt.cat.
///
/// The result is a tensor.empty of the result shape with lhs written at
/// [0, sizeA) and rhs at [sizeA, sizeA + sizeB). Two single-element inputs
/// are handled with tensor.extract / tensor.insert instead of
/// tensor.insert_slice.
///
/// Returns a match failure (no IR created) for non-1-D tensors or inputs
/// with a dynamic size.
LogicalResult
CatConverter::matchAndRewrite(triton::CatOp op, OpAdaptor adaptor,
                              ConversionPatternRewriter &rewriter) const {
  Value opa = op.getLhs();
  Value opb = op.getRhs();
  auto loc = op.getLoc();
  auto resType = dyn_cast<RankedTensorType>(op.getResult().getType());
  if (!resType || resType.getRank() != 1) {
    return rewriter.notifyMatchFailure(op, "only support 1D cat");
  }
  auto inputTypeA = dyn_cast<RankedTensorType>(opa.getType());
  auto inputTypeB = dyn_cast<RankedTensorType>(opb.getType());
  if (!inputTypeA || !inputTypeB || inputTypeA.getRank() != 1 ||
      inputTypeB.getRank() != 1) {
    return rewriter.notifyMatchFailure(op, "inputs must be 1D tensors");
  }
  int64_t sizeA = inputTypeA.getShape()[0];
  int64_t sizeB = inputTypeB.getShape()[0];
  // The slice offsets/sizes below are static attributes.
  if (ShapedType::isDynamic(sizeA) || ShapedType::isDynamic(sizeB)) {
    return rewriter.notifyMatchFailure(op, "inputs must have static sizes");
  }

  if (sizeA == 1 && sizeB == 1) {
    auto emptyOp = rewriter.create<tensor::EmptyOp>(
        loc, resType.getShape(), resType.getElementType());
    Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
    Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);
    Value scalarA = rewriter.create<tensor::ExtractOp>(loc, opa, zero);
    Value scalarB = rewriter.create<tensor::ExtractOp>(loc, opb, zero);
    Value inserted0 = rewriter.create<tensor::InsertOp>(loc, scalarA, emptyOp, zero);
    Value inserted1 = rewriter.create<tensor::InsertOp>(loc, scalarB, inserted0, one);
    rewriter.replaceOp(op, inserted1);
    return success();
  }

  auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, resType.getShape(),
                                                  resType.getElementType());
  SmallVector<OpFoldResult> strides = {rewriter.getIndexAttr(1)};
  // Each input is inserted with its own extent so unequal sizes stay correct.
  SmallVector<OpFoldResult> offsetsA = {rewriter.getIndexAttr(0)};
  SmallVector<OpFoldResult> sizesA = {rewriter.getIndexAttr(sizeA)};
  SmallVector<OpFoldResult> offsetsB = {rewriter.getIndexAttr(sizeA)};
  SmallVector<OpFoldResult> sizesB = {rewriter.getIndexAttr(sizeB)};
  auto insert0 = rewriter.create<tensor::InsertSliceOp>(
      loc, opa, emptyOp, offsetsA, sizesA, strides);
  auto insert1 = rewriter.create<tensor::InsertSliceOp>(
      loc, opb, insert0, offsetsB, sizesB, strides);
  rewriter.replaceOp(op, insert1);
  return success();
}
/// Lowers tt.gather to a call of a fresh private external function.
///
/// The function name is the first free symbol among gatherFuncNameBase,
/// gatherFuncNameBase_0, gatherFuncNameBase_1, ... Its signature is
/// (src, indices, axis : i32) -> result, and it is declared just before the
/// module's last op.
LogicalResult
GatherConverter::matchAndRewrite(triton::GatherOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  Value src = adaptor.getSrc();
  Value idx = adaptor.getIndices();
  auto resTy = op.getResult().getType();
  auto gatherAxis = op.getAxis();
  auto moduleOp = op->getParentOfType<ModuleOp>();

  llvm::SmallString<kFuncNameCap> funcName = gatherFuncNameBase;
  for (int suffix = 0; SymbolTable::lookupSymbolIn(moduleOp, funcName); ++suffix) {
    funcName = gatherFuncNameBase;
    funcName += "_" + std::to_string(suffix);
  }

  rewriter.setInsertionPoint(moduleOp.getBody(),
                             std::prev(moduleOp.getBody()->end()));
  auto calleeType = rewriter.getFunctionType(
      {src.getType(), idx.getType(), rewriter.getI32Type()}, {resTy});
  auto calleeOp = rewriter.create<func::FuncOp>(loc, funcName.str(), calleeType);
  SymbolTable::setSymbolVisibility(calleeOp, SymbolTable::Visibility::Private);

  rewriter.setInsertionPoint(op);
  Value axisVal = rewriter.create<arith::ConstantIntOp>(loc, gatherAxis, 32);
  auto call = rewriter.create<func::CallOp>(loc, calleeOp.getSymNameAttr(),
                                            TypeRange({resTy}),
                                            ValueRange({src, idx, axisVal}));
  rewriter.replaceOp(op, call);
  return success();
}
/// Lowers tt.split into two rank-reducing tensor.extract_slice ops.
///
/// Along the trailing input dimension, the lhs result takes the element at
/// offset 0 and the rhs result the element at offset 1. Each uses size 1 and
/// stride 2 there, and full extents with stride 1 elsewhere.
LogicalResult
SplitConverter::matchAndRewrite(triton::SplitOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  Value input = op.getSrc();
  auto loc = op.getLoc();
  auto inputType = cast<RankedTensorType>(input.getType());
  auto outType = dyn_cast<RankedTensorType>(op.getOutLHS().getType());
  const int64_t rank = inputType.getRank();

  SmallVector<OpFoldResult> sizes;
  for (int64_t extent : outType.getShape()) {
    sizes.push_back(rewriter.getIndexAttr(extent));
  }
  sizes.push_back(rewriter.getIndexAttr(1));

  SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
  strides.back() = rewriter.getIndexAttr(2);

  SmallVector<Value, 2> slices;
  for (int64_t lane = 0; lane < 2; ++lane) {
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    offsets.back() = rewriter.getIndexAttr(lane);
    auto slice = rewriter.create<tensor::ExtractSliceOp>(
        loc, outType, input, offsets, sizes, strides);
    slices.push_back(slice.getResult());
  }
  rewriter.replaceOp(op, ValueRange(slices));
  return success();
}
/*
 * tt.mulhiui: the element-wise most significant N bits of the 2N-bit
 * unsigned product of x and y, lowered as e.g.
 *   %x:2 = arith.mului_extended %y, %z : tensor<4x?xi32>
 * keeping only the high half (%x#1).
 */
/// Lowers tt.mulhiui to arith.mului_extended and keeps only the high half of
/// the 2N-bit unsigned product.
LogicalResult TritonMulhiuiConverter::matchAndRewrite(
    triton::MulhiUIOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  // Use the adaptor's operands: in a dialect conversion these are the
  // remapped values, whereas op.getX()/getY() may still refer to values
  // produced by already-replaced ops.
  Value lhs = adaptor.getX();
  Value rhs = adaptor.getY();
  Type resultType = op.getResult().getType();
  auto mulExtended = rewriter.create<arith::MulUIExtendedOp>(
      loc, resultType, resultType, lhs, rhs);
  rewriter.replaceOp(op, ValueRange{mulExtended.getHigh()});
  return success();
}
/// Lowers tt.precise_sqrt one-to-one onto math.sqrt over the converted
/// operands, at the original op's location.
LogicalResult TritonPreciseSqrtConverter::matchAndRewrite(
    triton::PreciseSqrtOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto sqrtOp = rewriter.create<math::SqrtOp>(op.getLoc(), adaptor.getOperands());
  rewriter.replaceOp(op, sqrtOp.getOperation());
  return success();
}
/// Lowers tt.print to a call of a fresh private external function.
///
/// The function is uniquely named from printFuncNameBase and declared just
/// before the module's last op. It takes the printed values' types and
/// returns nothing. The print's prefix and hex attributes are copied onto
/// the declaration under prefixAttrName / hexAttrName.
LogicalResult DevicePrintConverter::matchAndRewrite(
    triton::PrintOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto moduleOp = op->getParentOfType<ModuleOp>();
  auto printedTypes = op.getArgs().getTypes();
  SmallVector<Type, 4> argTypes(printedTypes.begin(), printedTypes.end());

  rewriter.setInsertionPoint(moduleOp.getBody(),
                             std::prev(moduleOp.getBody()->end()));
  auto calleeOp = rewriter.create<func::FuncOp>(
      op.getLoc(), printFuncNameBase, rewriter.getFunctionType(argTypes, {}));
  SymbolTable symTab(moduleOp);
  if (failed(symTab.renameToUnique(calleeOp, {&symTab}))) {
    return op->emitError(
        "failed to create a unique func name for device_print");
  }
  SymbolTable::setSymbolVisibility(calleeOp, SymbolTable::Visibility::Private);
  calleeOp->setAttr(prefixAttrName, op.getPrefixAttr());
  calleeOp->setAttr(hexAttrName, op.getHexAttr());

  rewriter.setInsertionPoint(op);
  rewriter.create<func::CallOp>(op.getLoc(), calleeOp, op.getArgs());
  rewriter.eraseOp(op);
  return success();
}
/// Lowers tt.assert to a call of a fresh private external function that
/// takes the condition and carries the assert message under msgAttrName.
///
/// Asserts whose message contains "overflow detected for operation" are
/// erased instead of lowered.
LogicalResult DeviceAssertConverter::matchAndRewrite(
    triton::AssertOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  auto msgAttr = op.getMessageAttr();
  auto msgStr = mlir::dyn_cast<mlir::StringAttr>(msgAttr);
  if (msgStr && msgStr.getValue().contains("overflow detected for operation")) {
    rewriter.eraseOp(op);
    return success();
  }

  auto moduleOp = op->getParentOfType<ModuleOp>();
  Value condition = op.getCondition();
  rewriter.setInsertionPoint(moduleOp.getBody(),
                             std::prev(moduleOp.getBody()->end()));
  // NOTE(review): the declaration is named from printFuncNameBase, the same
  // base DevicePrintConverter uses; confirm this is intended for asserts.
  auto conditionType = condition.getType();
  auto calleeOp = rewriter.create<func::FuncOp>(
      op.getLoc(), printFuncNameBase,
      rewriter.getFunctionType({conditionType}, {}));
  mlir::SymbolTable symTab(moduleOp);
  if (failed(symTab.renameToUnique(calleeOp, {&symTab}))) {
    return op->emitError(
        "failed to create a unique func name for device_assert");
  }
  SymbolTable::setSymbolVisibility(calleeOp, SymbolTable::Visibility::Private);
  calleeOp->setAttr(msgAttrName, msgAttr);

  rewriter.setInsertionPoint(op);
  rewriter.create<func::CallOp>(op.getLoc(), calleeOp, ValueRange{condition});
  rewriter.eraseOp(op);
  return success();
}
/// Lowers tt.dot to linalg.matmul (rank-2 result) or linalg.batch_matmul
/// (rank-3 result), accumulating into the C operand.
///
/// When the result element type is a float narrower than f32, C is widened
/// to f32 with arith.extf and the matmul accumulates in f32. The result is
/// then narrowed back with arith.truncf carrying round_mode = RINT. The
/// linalg op is tagged with the dot's input precision as the string
/// attribute "input_precision".
///
/// Returns a match failure, before any IR is created, for results that are
/// neither rank 2 nor rank 3.
LogicalResult
MatmulConverter::matchAndRewrite(triton::DotOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  auto opa = adaptor.getA();
  auto opb = adaptor.getB();
  auto opc = adaptor.getC();
  auto dstType = cast<RankedTensorType>(op.getType());
  auto elemTy = dstType.getElementType();
  auto inputPrec = op.getInputPrecision();
  const int64_t rank = dstType.getRank();
  if (rank != 2 && rank != 3) {
    return rewriter.notifyMatchFailure(
        op, "Datatype of DotOp operands could only be 2D or 3D");
  }
  auto loc = op.getLoc();

  // Creates linalg.matmul / linalg.batch_matmul reading A and B and
  // accumulating into `acc`.
  auto createMatmul = [&](Value acc) -> Operation * {
    if (rank == 2) {
      return rewriter.create<linalg::MatmulOp>(loc, ValueRange{opa, opb},
                                               ValueRange{acc});
    }
    return rewriter.create<linalg::BatchMatmulOp>(loc, ValueRange{opa, opb},
                                                  ValueRange{acc});
  };

  // Only widen floats narrower than f32: arith.extf requires a strictly
  // wider result type, so e.g. f64 takes the direct path.
  auto floatTy = dyn_cast<FloatType>(elemTy);
  Operation *matmulOp = nullptr;
  if (floatTy && floatTy.getWidth() < 32) {
    auto accF32Type =
        RankedTensorType::get(dstType.getShape(), rewriter.getF32Type());
    Value accF32 = rewriter.create<arith::ExtFOp>(loc, accF32Type, opc);
    matmulOp = createMatmul(accF32);
    auto roundModeAttr = hfusion::RoundModeAttr::get(
        rewriter.getContext(), hfusion::RoundMode::RINT);
    auto truncOp = rewriter.replaceOpWithNewOp<arith::TruncFOp>(
        op, dstType, matmulOp->getResult(0));
    truncOp->setAttr("round_mode", roundModeAttr);
  } else {
    matmulOp = createMatmul(opc);
    rewriter.replaceOp(op, matmulOp->getResults());
  }
  matmulOp->setAttr(
      "input_precision",
      rewriter.getStringAttr(stringifyInputPrecision(inputPrec)));
  return success();
}
// Lowers triton::ascend::FlipOp to a call to a private external `func.func`
// whose signature is (src tensor, dim : i64) -> tensor of the src type.
// The callee name is `baseFuncName`, or `baseFuncName_<k>` for the first k
// that is not yet a symbol in the enclosing module.
// Fails (with an error) if the op lacks an integer "dim" attribute or is not
// nested in a ModuleOp.
LogicalResult FlipOpConverter::matchAndRewrite(triton::ascend::FlipOp op, OpAdaptor adaptor,
                                               ConversionPatternRewriter &rewriter) const
{
  Value src = adaptor.getSrc();
  auto rankedSrcTy = cast<RankedTensorType>(src.getType());
  MLIRContext *ctx = rewriter.getContext();
  Location loc = op.getLoc();
  auto dimAttr = op->getAttrOfType<IntegerAttr>("dim");
  if (!dimAttr) {
    op->emitError("missing 'dim' attribute");
    return failure();
  }
  auto moduleOp = op->getParentOfType<ModuleOp>();
  if (!moduleOp) {
    op->emitError("must be inside a module");
    return failure();
  }

  // Pick the first callee name not already used in the module.
  std::string funcName = baseFuncName.str();
  int uniqueId = 0;
  while (SymbolTable::lookupSymbolIn(moduleOp, funcName))
    funcName = (baseFuncName + Twine("_") + Twine(uniqueId++)).str();

  auto i64Ty = IntegerType::get(ctx, 64);
  auto libFnType = rewriter.getFunctionType({rankedSrcTy, i64Ty}, {rankedSrcTy});
  func::FuncOp funcOp;
  {
    // RAII guard: the caller's insertion point is restored at scope exit.
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToEnd(moduleOp.getBody());
    funcOp = rewriter.create<func::FuncOp>(loc, funcName, libFnType);
    SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private);
  }

  Value dimVal = rewriter.create<arith::ConstantIntOp>(loc, dimAttr.getInt(), 64);
  auto callee = SymbolRefAttr::get(ctx, funcOp.getSymName());
  auto callOp = rewriter.create<func::CallOp>(loc, TypeRange({rankedSrcTy}), callee,
                                              ValueRange({src, dimVal}));
  rewriter.replaceOp(op, callOp.getResults());
  return success();
}
// Lowers triton::ascend::SortOp to a call to a private external `func.func`
// named "triton_sort" (or "triton_sort_<k>" for the first unused k), with
// signature (values, dim : i64, descending : i1) -> values.
// i8 inputs are passed to the callee as f16 and i16 inputs as f32 (sitofp
// before the call, fptosi after it). All other element types go through
// unchanged.
// Fails (with an error) if "dim"/"descending" are missing or the op is not
// nested in a ModuleOp.
LogicalResult SortOpConverter::matchAndRewrite(
    triton::ascend::SortOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const
{
  MLIRContext *ctx = rewriter.getContext();
  Location loc = op.getLoc();
  Value input = adaptor.getSrc();
  Type inputTy = input.getType();
  auto inputTensorTy = cast<RankedTensorType>(inputTy);
  Type inputElemTy = inputTensorTy.getElementType();

  // Element type actually handed to the callee.
  Type proxyElemTy = inputElemTy;
  if (inputElemTy.isInteger(8))
    proxyElemTy = Float16Type::get(ctx);
  else if (inputElemTy.isInteger(16))
    proxyElemTy = Float32Type::get(ctx);
  const bool usesProxy = proxyElemTy != inputElemTy;
  Type proxyTensorTy = RankedTensorType::get(
      inputTensorTy.getShape(), proxyElemTy, inputTensorTy.getEncoding());

  auto dimAttr = op->getAttrOfType<IntegerAttr>("dim");
  auto descAttr = op->getAttrOfType<BoolAttr>("descending");
  if (!dimAttr || !descAttr) {
    op->emitError("missing 'dim' or 'descending' attribute");
    return failure();
  }
  ModuleOp moduleOp = op->getParentOfType<ModuleOp>();
  if (!moduleOp) {
    op->emitError("must be inside a module");
    return failure();
  }

  // First free symbol among triton_sort, triton_sort_0, triton_sort_1, ...
  const llvm::SmallString<64> baseName("triton_sort");
  llvm::SmallString<64> funcName = baseName;
  for (int suffix = 0; SymbolTable::lookupSymbolIn(moduleOp, funcName); ++suffix) {
    funcName = baseName;
    funcName += "_" + std::to_string(suffix);
  }

  Type i64Ty = IntegerType::get(ctx, 64);
  Type i1Ty = IntegerType::get(ctx, 1);
  FunctionType libFnType =
      rewriter.getFunctionType({proxyTensorTy, i64Ty, i1Ty}, {proxyTensorTy});
  func::FuncOp funcOp;
  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToEnd(moduleOp.getBody());
    funcOp = rewriter.create<func::FuncOp>(loc, funcName.str(), libFnType);
    SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private);
  }

  Value callArg = input;
  if (usesProxy)
    callArg = rewriter.create<arith::SIToFPOp>(loc, proxyTensorTy, input);
  Value dimVal = rewriter.create<arith::ConstantIntOp>(loc, dimAttr.getInt(), 64);
  Value descVal = rewriter.create<arith::ConstantIntOp>(loc, descAttr.getValue() ? 1 : 0, 1);
  auto callOp = rewriter.create<func::CallOp>(
      loc, TypeRange({proxyTensorTy}), SymbolRefAttr::get(ctx, funcOp.getSymName()),
      ValueRange({callArg, dimVal, descVal}));

  Value sorted = callOp.getResult(0);
  if (usesProxy)
    sorted = rewriter.create<arith::FPToSIOp>(loc, inputTy, sorted);
  rewriter.replaceOp(op, {sorted});
  return success();
}
// Lowers triton::DotScaledOp.
//
// Two paths:
//  * Both operands FP8 (E4M3/E5M2) or both FP4 (E2M1): emit a single
//    hfusion::MatMulMxOp that consumes the operands, their scales, the
//    accumulator and the per-operand data formats.
//  * Otherwise: decode the scales in software, multiply them element-wise into
//    lhs (and into rhs when an integer rhs scale is present), optionally
//    propagate NaN scales (when fast-math is off), and emit
//    linalg.matmul / linalg.batch_matmul.
LogicalResult
DotScaledConverter::matchAndRewrite(triton::DotScaledOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const
{
Location loc = op.getLoc();
Value lhs = adaptor.getA();
Value rhs = adaptor.getB();
Value c = adaptor.getC();
Value lhsScale = adaptor.getAScale();
Value rhsScale = adaptor.getBScale();
RankedTensorType dstType = cast<RankedTensorType>(op.getType());
auto lhsElemType = op.getAElemType();
auto rhsElemType = op.getBElemType();
bool isFP8Input = (lhsElemType == triton::ScaleDotElemType::E4M3 ||
lhsElemType == triton::ScaleDotElemType::E5M2) &&
(rhsElemType == triton::ScaleDotElemType::E4M3 ||
rhsElemType == triton::ScaleDotElemType::E5M2);
bool isFP4Input = (lhsElemType == triton::ScaleDotElemType::E2M1) &&
(rhsElemType == triton::ScaleDotElemType::E2M1);
// ---- Hardware MX path: FP8 x FP8 or FP4 x FP4 -> hfusion::MatMulMxOp ----
if (isFP8Input || isFP4Input) {
// A missing rhs scale is replaced by a 1-element i8 tensor filled with 1.
// NOTE(review): lhsScale gets no such default; a null lhsScale would be
// passed straight to MatMulMxOp — confirm it is always present here.
if (!rhsScale) {
RankedTensorType defaultScaleTy = RankedTensorType::get({1}, rewriter.getI8Type());
Value defaultScaleVal = rewriter.create<arith::ConstantOp>(loc, rewriter.getI8IntegerAttr(1));
Value defaultScaleEmpty = rewriter.create<tensor::EmptyOp>(loc, defaultScaleTy.getShape(), defaultScaleTy.getElementType());
rhsScale = rewriter.create<linalg::FillOp>(loc, ValueRange{defaultScaleVal}, ValueRange{defaultScaleEmpty}).getResult(0);
}
// Use `c` as the accumulator when present, otherwise an uninitialized
// tensor.empty of the result type.
Value acc = c ? c : rewriter.create<tensor::EmptyOp>(loc, dstType.getShape(), dstType.getElementType());
// Maps the Triton scaled-dot element kind to the hfusion data-format attr.
auto convertFormat = [&](triton::ScaleDotElemType ty) -> mlir::hfusion::DataformatAttr {
auto ctx = rewriter.getContext();
switch (ty) {
case triton::ScaleDotElemType::E2M1:
return mlir::hfusion::DataformatAttr::get(ctx, mlir::hfusion::Dataformat::FP4E2M1_T);
case triton::ScaleDotElemType::E4M3:
return mlir::hfusion::DataformatAttr::get(ctx, mlir::hfusion::Dataformat::FP8E4M3_T);
case triton::ScaleDotElemType::E5M2:
return mlir::hfusion::DataformatAttr::get(ctx, mlir::hfusion::Dataformat::FP8E5M2_T);
default:
llvm_unreachable("unsupported ScaleDotElemType");
}
};
auto lhsFmt = convertFormat(lhsElemType);
auto rhsFmt = convertFormat(rhsElemType);
Value matmulMxResult = rewriter.create<hfusion::MatMulMxOp>(
loc,
dstType,
lhs,
rhs,
lhsScale,
rhsScale,
acc,
lhsFmt,
rhsFmt
);
Value finalResult = matmulMxResult;
// NOTE(review): matmulMxResult is created with result type `dstType`, so
// this truncf would be dstType -> dstType (same width). arith.truncf
// requires a strictly narrower result — confirm MatMulMxOp's actual result
// type on the bf16 path.
if (dstType.getElementType().isBF16()) {
finalResult = rewriter.create<arith::TruncFOp>(loc, dstType, matmulMxResult);
}
rewriter.replaceOp(op, finalResult);
return success();
}
// ---- Software path: apply scales element-wise, then a plain matmul ----
if (!lhsScale) {
return op.emitError("lhsScale is required for non-FP8 input");
}
RankedTensorType lhsTy = cast<RankedTensorType>(lhs.getType());
RankedTensorType lhsScaleTy = cast<RankedTensorType>(lhsScale.getType());
RankedTensorType rhsScaleTy = rhsScale ? cast<RankedTensorType>(rhsScale.getType()) : nullptr;
RankedTensorType rhsTy = cast<RankedTensorType>(rhs.getType());
Value lhsScaleOut;
Value rhsScaleOut;
// Integer scales are decoded as bf16(bits = (sext(scale) + 127) << 7). bf16
// has 7 mantissa bits, so this writes (scale + 127) into the exponent field,
// i.e. 2^scale while the biased exponent stays in range.
// NOTE(review): the sign extension plus the +127 bias treat the scale as a
// signed, unbiased exponent; if inputs are E8M0 (already biased, unsigned)
// this would double-bias — confirm the expected scale encoding.
Value c127 = rewriter.create<arith::ConstantOp>(
op.getLoc(),
rewriter.getI16Type(),
rewriter.getI16IntegerAttr(127)
);
Value c7 = rewriter.create<arith::ConstantOp>(
op.getLoc(),
rewriter.getI16Type(),
rewriter.getI16IntegerAttr(7)
);
Type i16Ty = rewriter.getI16Type();
Type bf16Ty = rewriter.getBF16Type();
Type fp16Ty = rewriter.getF16Type();
Type fp32Ty = rewriter.getF32Type();
bool fastMath = op.getFastMath();
// Returns a tensor of `tensorTy` filled with the NaN of its float element type.
auto createNanSplat = [&](RankedTensorType tensorTy) -> Value {
auto floatTy = cast<FloatType>(tensorTy.getElementType());
auto nanAttr = rewriter.getFloatAttr(
floatTy,
APFloat::getNaN(floatTy.getFloatSemantics())
);
Value empty = rewriter.create<tensor::EmptyOp>(loc, tensorTy.getShape(), tensorTy.getElementType());
return rewriter.create<linalg::FillOp>(loc, ValueRange{rewriter.create<arith::ConstantOp>(loc, nanAttr)}, ValueRange{empty}).getResult(0);
};
// Returns an i1 tensor marking NaN scales: for integer scales an all-ones
// bit pattern is the sentinel (presumably E8M0's 0xFF NaN encoding); for
// float scales it is an unordered self-comparison (x != x).
auto createNaNMask = [&](Value scaleTensor, RankedTensorType scaleTy) -> Value {
if (scaleTy.getElementType().isIntOrIndex()) {
auto bitWidth = scaleTy.getElementTypeBitWidth();
auto allOnes = APInt::getAllOnes(bitWidth);
auto sentinel = rewriter.create<arith::ConstantOp>(
loc,
scaleTy,
DenseElementsAttr::get(scaleTy, allOnes)
);
return rewriter.create<arith::CmpIOp>(
loc,
arith::CmpIPredicate::eq,
scaleTensor,
sentinel
).getResult();
}
return rewriter.create<arith::CmpFOp>(
loc,
arith::CmpFPredicate::UNO,
scaleTensor,
scaleTensor
).getResult();
};
// Element-wise select: NaN where `maskTensor` is set, else `valueTensor`.
auto applyNaNMask = [&](Value valueTensor, Value maskTensor) -> Value {
auto valueTy = cast<RankedTensorType>(valueTensor.getType());
Value nanTensor = createNanSplat(valueTy);
return rewriter.create<arith::SelectOp>(loc, maskTensor, nanTensor, valueTensor).getResult();
};
// Decode lhsScale into a float tensor `lhsScaleOut` (same shape as lhsScale).
if (lhsScaleTy.getElementType().isIntOrIndex()) {
RankedTensorType lhsScaleI16Ty = RankedTensorType::get(lhsScaleTy.getShape(), i16Ty);
Value lhsScaleI16 = rewriter.create<arith::ExtSIOp>(
op.getLoc(),
lhsScaleI16Ty,
lhsScale
);
Value lhsShift127Empty = rewriter.create<tensor::EmptyOp>(
op.getLoc(),
lhsScaleI16Ty.getShape(),
i16Ty
);
Value lhsShift127 = rewriter.create<linalg::FillOp>(
op.getLoc(),
ValueRange{c127},
ValueRange{lhsShift127Empty}
).getResult(0);
Value lhsScaleI16Add127 = rewriter.create<arith::AddIOp>(
op.getLoc(),
lhsScaleI16,
lhsShift127
);
Value lhsShift7Empty = rewriter.create<tensor::EmptyOp>(
op.getLoc(),
lhsScaleI16Ty.getShape(),
i16Ty
);
Value lhsShift7 = rewriter.create<linalg::FillOp>(
op.getLoc(),
ValueRange{c7},
ValueRange{lhsShift7Empty}
).getResult(0);
Value lhsScaleI16Shifted = rewriter.create<arith::ShLIOp>(
op.getLoc(),
lhsScaleI16Add127,
lhsShift7
);
RankedTensorType lhsScaleBF16Ty = RankedTensorType::get(lhsScaleTy.getShape(), bf16Ty);
Value lhsScaleBF16 = rewriter.create<arith::BitcastOp>(
op.getLoc(),
lhsScaleBF16Ty,
lhsScaleI16Shifted
);
// fp16 lhs: re-express the bf16 scale as f16 (via f32) so it matches lhs.
if (lhsTy.getElementType() == fp16Ty) {
RankedTensorType lhsScaleFp32Ty = RankedTensorType::get(lhsScaleTy.getShape(), fp32Ty);
Value lhsScaleFp32 = rewriter.create<arith::ExtFOp>(
op.getLoc(),
lhsScaleFp32Ty,
lhsScaleBF16
);
RankedTensorType lhsScaleFp16Ty = RankedTensorType::get(lhsScaleTy.getShape(), fp16Ty);
lhsScaleOut = rewriter.create<arith::TruncFOp>(
op.getLoc(),
lhsScaleFp16Ty,
lhsScaleFp32
);
} else {
lhsScaleOut = lhsScaleBF16;
}
} else {
// Float scale: widened to f32.
// NOTE(review): arith.extf needs a strictly wider result, so an f32 scale
// would be invalid here; also the f32 result is later expanded with lhs's
// element type — confirm the two agree for the float-scale case.
lhsScaleOut = rewriter.create<arith::ExtFOp>(
op.getLoc(),
RankedTensorType::get(lhsScaleTy.getShape(), fp32Ty),
lhsScale
).getResult();
}
// rhs is scaled only when an integer rhs scale is present; otherwise rhs is
// used unscaled. The 2D scale is transposed, decoded like lhs, then
// repeated along rhs dim 0 to match rhs's shape.
if (rhsScale && rhsScaleTy.getElementType().isIntOrIndex()) {
if (rhsScaleTy.getRank() != 2) {
return op.emitError("rhsScale must be 2D for transpose");
}
SmallVector<int64_t> transposedShape = {
rhsScaleTy.getShape()[1],
rhsScaleTy.getShape()[0]
};
RankedTensorType transposedRhsScaleTy = RankedTensorType::get(
transposedShape,
rhsScaleTy.getElementType()
);
Value transposedRhsScale = rewriter.create<triton::TransOp>(
op.getLoc(),
transposedRhsScaleTy,
rhsScale,
DenseI32ArrayAttr::get(
rewriter.getContext(),
ArrayRef<int32_t>{1, 0})
);
RankedTensorType rhsScaleI16Ty = RankedTensorType::get(
transposedShape,
i16Ty);
Value rhsScaleI16 = rewriter.create<arith::ExtSIOp>(
op.getLoc(),
rhsScaleI16Ty,
transposedRhsScale
);
Value rhsShift127Empty = rewriter.create<tensor::EmptyOp>(
op.getLoc(),
rhsScaleI16Ty.getShape(),
i16Ty
);
Value rhsShift127 = rewriter.create<linalg::FillOp>(
op.getLoc(),
ValueRange{c127},
ValueRange{rhsShift127Empty}
).getResult(0);
Value rhsScaleI16Add127 = rewriter.create<arith::AddIOp>(
op.getLoc(),
rhsScaleI16,
rhsShift127
);
Value rhsShift7Empty = rewriter.create<tensor::EmptyOp>(
op.getLoc(),
rhsScaleI16Ty.getShape(),
i16Ty
);
Value rhsShift7 = rewriter.create<linalg::FillOp>(
op.getLoc(),
ValueRange{c7},
ValueRange{rhsShift7Empty}
).getResult(0);
Value rhsScaleI16Shifted = rewriter.create<arith::ShLIOp>(
op.getLoc(),
rhsScaleI16Add127,
rhsShift7
);
RankedTensorType rhsScaleBF16Ty = RankedTensorType::get(transposedShape, bf16Ty);
Value rhsScaleBF16 = rewriter.create<arith::BitcastOp>(
op.getLoc(),
rhsScaleBF16Ty,
rhsScaleI16Shifted
);
if (rhsTy.getElementType() == fp16Ty) {
RankedTensorType rhsScaleFp32Ty = RankedTensorType::get(transposedShape, fp32Ty);
Value rhsScaleFp32 = rewriter.create<arith::ExtFOp>(
op.getLoc(),
rhsScaleFp32Ty,
rhsScaleBF16
);
RankedTensorType rhsScaleFp16Ty = RankedTensorType::get(transposedShape, fp16Ty);
rhsScaleOut = rewriter.create<arith::TruncFOp>(
op.getLoc(),
rhsScaleFp16Ty,
rhsScaleFp32
);
} else {
rhsScaleOut = rhsScaleBF16;
}
// Shape of the transposed scale: (rhsD0, rhsD1).
int64_t rhsD0 = rhsScaleTy.getShape()[1];
int64_t rhsD1 = rhsScaleTy.getShape()[0];
// (rhsD0, rhsD1) -> (rhsD0, rhsD1, 1) -> broadcast (rhsD0, rhsD1, rhsD2)
// -> transpose (rhsD0, rhsD2, rhsD1) -> collapse (rhsD0*rhsD2, rhsD1),
// so each scale row is repeated rhsD2 times along rhs dim 0.
SmallVector<int64_t> rhsExpandedShape1 = {rhsD0, rhsD1, 1};
RankedTensorType rhsExpandedTy1 = RankedTensorType::get(rhsExpandedShape1, rhsTy.getElementType());
Value rhsExpanded1 = rewriter.create<triton::ExpandDimsOp>(
op.getLoc(),
rhsExpandedTy1,
rhsScaleOut,
rewriter.getI32IntegerAttr(2)
).getResult();
int64_t rhsDim1 = rhsTy.getShape()[0];
if (rhsDim1 % rhsD0 != 0) {
return op.emitError("rhs dim0 must be an integer multiple of rhsScale dim0");
}
int64_t rhsD2 = rhsDim1 / rhsD0;
SmallVector<int64_t> rhsBroadcastShape = {rhsD0, rhsD1, rhsD2};
RankedTensorType rhsBroadcastTy = RankedTensorType::get(rhsBroadcastShape, rhsTy.getElementType());
Value rhsBroadcasted = rewriter.create<triton::BroadcastOp>(
op.getLoc(),
rhsBroadcastTy,
rhsExpanded1
).getResult();
SmallVector<int32_t> transposeOrder = {0, 2, 1};
Value transposedBroadcasted = rewriter.create<triton::TransOp>(
op.getLoc(),
RankedTensorType::get({rhsD0, rhsD2, rhsD1}, rhsTy.getElementType()),
rhsBroadcasted,
DenseI32ArrayAttr::get(rewriter.getContext(), transposeOrder)
);
SmallVector<ReassociationIndices> rhsReassociation;
rhsReassociation.push_back({0, 1});
rhsReassociation.push_back({2});
Value scaledRhs = rewriter.create<tensor::CollapseShapeOp>(
op.getLoc(),
RankedTensorType::get({rhsD0 * rhsD2, rhsD1}, rhsTy.getElementType()),
transposedBroadcasted,
rhsReassociation
).getResult();
rhs = rewriter.create<arith::MulFOp>(
op.getLoc(),
rhs,
scaledRhs
).getResult();
// Without fast-math, rhs elements whose scale is the NaN sentinel become
// NaN; the mask goes through the same expand/broadcast/transpose/collapse.
if (!fastMath) {
Value rhsScaleNaNMask = createNaNMask(transposedRhsScale, transposedRhsScaleTy);
Value rhsExpandedMask = rewriter.create<triton::ExpandDimsOp>(
op.getLoc(),
RankedTensorType::get(rhsExpandedShape1, rewriter.getI1Type()),
rhsScaleNaNMask,
rewriter.getI32IntegerAttr(2)
).getResult();
Value rhsBroadcastMask = rewriter.create<triton::BroadcastOp>(
op.getLoc(),
RankedTensorType::get(rhsBroadcastShape, rewriter.getI1Type()),
rhsExpandedMask
).getResult();
Value transposedBroadcastMask = rewriter.create<triton::TransOp>(
op.getLoc(),
RankedTensorType::get({rhsD0, rhsD2, rhsD1}, rewriter.getI1Type()),
rhsBroadcastMask,
DenseI32ArrayAttr::get(rewriter.getContext(), transposeOrder)
).getResult();
Value collapsedRhsMask = rewriter.create<tensor::CollapseShapeOp>(
op.getLoc(),
RankedTensorType::get({rhsD0 * rhsD2, rhsD1}, rewriter.getI1Type()),
transposedBroadcastMask,
rhsReassociation
).getResult();
rhs = applyNaNMask(rhs, collapsedRhsMask);
}
}
// lhs scale (D0, D1) -> (D0, D1, 1) -> broadcast (D0, D1, D2)
// -> collapse (D0, D1*D2): each scale is repeated D2 times along lhs dim 1.
int64_t D0 = lhsScaleTy.getShape()[0];
int64_t D1 = lhsScaleTy.getShape()[1];
SmallVector<int64_t> expandedShape1 = {D0, D1, 1};
RankedTensorType expandedTy1 = RankedTensorType::get(expandedShape1, lhsTy.getElementType());
Value expanded1 = rewriter.create<triton::ExpandDimsOp>(
op.getLoc(),
expandedTy1,
lhsScaleOut,
rewriter.getI32IntegerAttr(2)
).getResult();
int64_t lhsDim1 = lhsTy.getShape()[1];
if (lhsDim1 % D1 != 0) {
return op.emitError("lhs dim1 must be an integer multiple of lhsScale dim1");
}
int64_t D2 = lhsDim1 / D1;
SmallVector<int64_t> broadcastShape = {D0, D1, D2};
RankedTensorType broadcastTy = RankedTensorType::get(broadcastShape, lhsTy.getElementType());
Value broadcasted = rewriter.create<triton::BroadcastOp>(
op.getLoc(),
broadcastTy,
expanded1
).getResult();
SmallVector<ReassociationIndices> reassociation;
reassociation.push_back({0});
reassociation.push_back({1, 2});
Value scaledLhs = rewriter.create<tensor::CollapseShapeOp>(
op.getLoc(),
RankedTensorType::get({D0, D1 * D2}, lhsTy.getElementType()),
broadcasted,
reassociation
).getResult();
Value scaledLhsFinal = rewriter.create<arith::MulFOp>(
op.getLoc(),
lhs,
scaledLhs
).getResult();
// Without fast-math, lhs elements whose scale is the NaN sentinel become NaN.
if (!fastMath) {
Value lhsScaleNaNMask = createNaNMask(lhsScale, lhsScaleTy);
Value lhsExpandedMask = rewriter.create<triton::ExpandDimsOp>(
op.getLoc(),
RankedTensorType::get(expandedShape1, rewriter.getI1Type()),
lhsScaleNaNMask,
rewriter.getI32IntegerAttr(2)
).getResult();
Value lhsBroadcastMask = rewriter.create<triton::BroadcastOp>(
op.getLoc(),
RankedTensorType::get(broadcastShape, rewriter.getI1Type()),
lhsExpandedMask
).getResult();
Value collapsedLhsMask = rewriter.create<tensor::CollapseShapeOp>(
op.getLoc(),
RankedTensorType::get({D0, D1 * D2}, rewriter.getI1Type()),
lhsBroadcastMask,
reassociation
).getResult();
scaledLhsFinal = applyNaNMask(scaledLhsFinal, collapsedLhsMask);
}
// Accumulate into `c`. NOTE(review): unlike the MX path, there is no
// fallback when `c` is null — confirm `c` is always present here.
Operation *matmulOp;
if (dstType.getRank() == 2) {
matmulOp = rewriter.create<linalg::MatmulOp>(
op.getLoc(), ValueRange{scaledLhsFinal, rhs}, ValueRange{c}
);
} else if (dstType.getRank() == 3) {
matmulOp = rewriter.create<linalg::BatchMatmulOp>(
op.getLoc(), ValueRange{scaledLhsFinal, rhs}, ValueRange{c}
);
} else {
return op.emitError("DotScaledOp only support 2D or 3D tensor");
}
rewriter.replaceOp(op, matmulOp->getResults());
return success();
}
// Lowers triton::PtrToIntOp on a memref source to the memref's aligned base
// pointer (memref.extract_aligned_pointer_as_index), index-cast to the op's
// integer result type. Non-memref sources are left for other patterns.
LogicalResult
PtrToIntConverter::matchAndRewrite(triton::PtrToIntOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const {
  Value basePtr = adaptor.getSrc();
  if (!mlir::isa<MemRefType>(basePtr.getType()))
    return rewriter.notifyMatchFailure(op, "input is not a memref type");

  Value alignedPtrIdx =
      rewriter.create<memref::ExtractAlignedPointerAsIndexOp>(op.getLoc(), basePtr);
  rewriter.replaceOpWithNewOp<arith::IndexCastOp>(op, op.getType(), alignedPtrIdx);
  return success();
}
// Lowers triton::ascend::IndexPutOp to a call to a private, uniquely named
// external `func.func` with no results. Callee arguments, in order:
// ptr (memref), index, value, dim, index boundary, end offsets,
// start offsets, destination strides.
LogicalResult
IndexPutConverter::matchAndRewrite(triton::ascend::IndexPutOp op, OpAdaptor adaptor,
                                   ConversionPatternRewriter &rewriter) const
{
  Location loc = op.getLoc();
  auto module = op->getParentOfType<ModuleOp>();
  Block *moduleBody = module.getBody();
  // The external declaration goes into the module body, ahead of its last op.
  rewriter.setInsertionPoint(moduleBody, std::prev(moduleBody->end()));
  auto funcName = generateUniqueFuncName(module, funcNameBase);

  Value ptr = adaptor.getPtr();
  if (!isa<MemRefType>(ptr.getType()))
    return rewriter.notifyMatchFailure(op, "expected MemRefType for ptr");

  auto endOffset = op.getEndOffset();
  auto startOffset = op.getStartOffset();
  auto dstStride = adaptor.getDstStride();
  SmallVector<Value> callArgs({ptr, op.getIndex(), op.getValue(), op.getDim(),
                               op.getIndexBoundary()});
  callArgs.append(endOffset.begin(), endOffset.end());
  callArgs.append(startOffset.begin(), startOffset.end());
  callArgs.append(dstStride.begin(), dstStride.end());

  // The callee signature mirrors the argument list exactly.
  auto libFnType = rewriter.getFunctionType(TypeRange(ValueRange(callArgs)), {});
  auto funcOp = rewriter.create<func::FuncOp>(loc, funcName.str(), libFnType);
  SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private);

  rewriter.setInsertionPoint(op);
  rewriter.create<func::CallOp>(loc, funcOp.getSymNameAttr(), TypeRange({}), callArgs);
  rewriter.eraseOp(op);
  return success();
}
// Lowers triton::ascend::GatherOutToUbOp to a call to a private, uniquely
// named external `func.func` returning the op's result type. Callee
// arguments, in order: src (memref), index, index boundary, dim, src strides,
// end offsets, start offsets, and `other` when present.
LogicalResult
GatherOutToUbConverter::matchAndRewrite(triton::ascend::GatherOutToUbOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter) const
{
  Location loc = op.getLoc();
  auto module = op->getParentOfType<ModuleOp>();
  Block *moduleBody = module.getBody();
  rewriter.setInsertionPoint(moduleBody, std::prev(moduleBody->end()));
  auto funcName = generateUniqueFuncName(module, funcNameBase);

  Value src = adaptor.getSrc();
  if (!isa<MemRefType>(src.getType()))
    return rewriter.notifyMatchFailure(op, "expected MemRefType for src");

  Type resTy = op.getResult().getType();
  auto srcStride = op.getSrcStride();
  auto endOffset = op.getEndOffset();
  auto startOffset = op.getStartOffset();
  SmallVector<Value> callArgs({src, op.getIndex(), op.getIndexBoundary(), op.getDim()});
  callArgs.append(srcStride.begin(), srcStride.end());
  callArgs.append(endOffset.begin(), endOffset.end());
  callArgs.append(startOffset.begin(), startOffset.end());
  if (Value other = op.getOther())
    callArgs.push_back(other);

  // The callee signature mirrors the argument list exactly.
  auto libFnType = rewriter.getFunctionType(TypeRange(ValueRange(callArgs)), {resTy});
  auto funcOp = rewriter.create<func::FuncOp>(loc, funcName.str(), libFnType);
  SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private);

  rewriter.setInsertionPoint(op);
  auto callOp = rewriter.create<func::CallOp>(loc, funcOp.getSymNameAttr(),
                                              TypeRange({resTy}), callArgs);
  rewriter.replaceOp(op, callOp);
  return success();
}
// Lowers triton::ascend::ScatterUbToOutOp to a call to a private, uniquely
// named external `func.func` with no results. Callee arguments, in order:
// ptr (memref), value, index, index boundary, dim, destination strides,
// end offsets, start offsets.
LogicalResult
ScatterUbToOutConverter::matchAndRewrite(triton::ascend::ScatterUbToOutOp op, OpAdaptor adaptor,
                                         ConversionPatternRewriter &rewriter) const
{
  Location loc = op.getLoc();
  auto module = op->getParentOfType<ModuleOp>();
  Block *moduleBody = module.getBody();
  rewriter.setInsertionPoint(moduleBody, std::prev(moduleBody->end()));
  auto funcName = generateUniqueFuncName(module, funcNameBase);

  Value ptr = adaptor.getPtr();
  if (!isa<MemRefType>(ptr.getType()))
    return rewriter.notifyMatchFailure(op, "expected MemRefType for ptr");

  auto dstStride = op.getDstStride();
  auto endOffset = op.getEndOffset();
  auto startOffset = op.getStartOffset();
  SmallVector<Value> callArgs({ptr, op.getValue(), op.getIndex(),
                               op.getIndexBoundary(), op.getDim()});
  callArgs.append(dstStride.begin(), dstStride.end());
  callArgs.append(endOffset.begin(), endOffset.end());
  callArgs.append(startOffset.begin(), startOffset.end());

  // The callee signature mirrors the argument list exactly.
  auto libFnType = rewriter.getFunctionType(TypeRange(ValueRange(callArgs)), {});
  auto funcOp = rewriter.create<func::FuncOp>(loc, funcName.str(), libFnType);
  SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private);

  rewriter.setInsertionPoint(op);
  rewriter.create<func::CallOp>(loc, funcOp.getSymNameAttr(), TypeRange({}), callArgs);
  rewriter.eraseOp(op);
  return success();
}
// Lowers triton::ascend::IndirectLoadOp to a call to a private, uniquely
// named external `func.func` returning the op's result type. Callee
// arguments, in order: src (memref), offsets, then mask and other when present.
LogicalResult
IndirectLoadConverter::matchAndRewrite(triton::ascend::IndirectLoadOp op, OpAdaptor adaptor,
                                       ConversionPatternRewriter &rewriter) const
{
  Location loc = op.getLoc();
  auto module = op->getParentOfType<ModuleOp>();
  Block *moduleBody = module.getBody();
  rewriter.setInsertionPoint(moduleBody, std::prev(moduleBody->end()));
  auto funcName = generateUniqueFuncName(module, funcNameBase);

  Value src = adaptor.getSrc();
  if (!isa<MemRefType>(src.getType()))
    return rewriter.notifyMatchFailure(op, "expected MemRefType for src");

  Type resTy = op.getResult().getType();
  SmallVector<Value> callArgs({src, op.getOffsets()});
  if (Value mask = op.getMask())
    callArgs.push_back(mask);
  if (Value other = op.getOther())
    callArgs.push_back(other);

  // The callee signature mirrors the argument list exactly.
  auto libFnType = rewriter.getFunctionType(TypeRange(ValueRange(callArgs)), {resTy});
  auto funcOp = rewriter.create<func::FuncOp>(loc, funcName.str(), libFnType);
  SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private);

  rewriter.setInsertionPoint(op);
  auto callOp = rewriter.create<func::CallOp>(loc, funcOp.getSymNameAttr(),
                                              TypeRange({resTy}), callArgs);
  rewriter.replaceOp(op, callOp);
  return success();
}
// Lowers triton::ascend::IndirectStoreOp to a call to a private, uniquely
// named external `func.func` with no results. Callee arguments, in order:
// src (memref), offsets, value, then mask when present.
LogicalResult
IndirectStoreConverter::matchAndRewrite(triton::ascend::IndirectStoreOp op, OpAdaptor adaptor,
                                        ConversionPatternRewriter &rewriter) const
{
  Location loc = op.getLoc();
  auto module = op->getParentOfType<ModuleOp>();
  Block *moduleBody = module.getBody();
  rewriter.setInsertionPoint(moduleBody, std::prev(moduleBody->end()));
  auto funcName = generateUniqueFuncName(module, funcNameBase);

  Value src = adaptor.getSrc();
  if (!isa<MemRefType>(src.getType()))
    return rewriter.notifyMatchFailure(op, "expected MemRefType for src");

  SmallVector<Value> callArgs({src, op.getOffsets(), op.getValue()});
  if (Value mask = op.getMask())
    callArgs.push_back(mask);

  // The callee signature mirrors the argument list exactly.
  auto libFnType = rewriter.getFunctionType(TypeRange(ValueRange(callArgs)), {});
  auto funcOp = rewriter.create<func::FuncOp>(loc, funcName.str(), libFnType);
  SymbolTable::setSymbolVisibility(funcOp, SymbolTable::Visibility::Private);

  rewriter.setInsertionPoint(op);
  rewriter.create<func::CallOp>(loc, funcOp.getSymNameAttr(), TypeRange({}), callArgs);
  rewriter.eraseOp(op);
  return success();
}
// Registers the pattern for triton::ascend::IndexSelectSimdOp with the
// default pattern benefit.
IndexSelectSimdConverter::IndexSelectSimdConverter(MLIRContext *context)
    : OpConversionPattern<triton::ascend::IndexSelectSimdOp>(context, /*benefit=*/1) {}
// Lowers triton::ascend::IndexSelectSimdOp into an explicit copy loop:
//  1. Reinterpret `src` as a contiguous row-major memref of the full source
//     shape. A dim is static when its shape operand is an arith.constant of
//     index type, unless it is the selected `dim` with read_shape == -1.
//  2. Allocate the result buffer and emit an scf.for (tagged
//     "hivm.parallel_loop") over the entries of the 1-D index tensor.
//  3. In iteration iv, copy the slice src[..., index[iv], ...] (1 along `dim`,
//     read_shape elsewhere, starting at src_offset) into out[..., iv, ...].
//     If `dim` is the innermost dim, the copy is tagged as discrete; otherwise
//     the destination is annotated for 32-byte stride alignment on `dim`.
//  4. Wrap the buffer in bufferization.to_tensor (restrict, writable).
LogicalResult
IndexSelectSimdConverter::matchAndRewrite(triton::ascend::IndexSelectSimdOp op, OpAdaptor adaptor,
                                          ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  Value src = adaptor.getSrc();
  // Reject non-memref sources like the sibling converters do, instead of
  // relying on a cast<> assertion.
  if (!isa<MemRefType>(src.getType()))
    return rewriter.notifyMatchFailure(op, "expected MemRefType for src");
  Value indexTensor = adaptor.getIndex();
  auto srcShapeVals = adaptor.getSrcShape();
  auto srcOffsetVals = adaptor.getSrcOffset();
  auto readShapeAttr = op.getReadShape();
  int32_t dim = op.getDim();
  auto resultTensorType = cast<RankedTensorType>(op.getResult().getType());
  auto elemType = resultTensorType.getElementType();
  auto resultShape = resultTensorType.getShape();
  ArrayRef<int32_t> readShape = readShapeAttr;
  const size_t rank = srcShapeVals.size();
  const size_t dimIdx = static_cast<size_t>(dim);

  // Index-casts `val` unless it already has `index` type (index_cast from
  // index to index is invalid).
  auto toIndexValue = [&](Value val) -> Value {
    if (!val.getType().isIndex()) {
      return rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(), val);
    }
    return val;
  };

  // Step 1a: full source sizes (static where a constant index is known).
  SmallVector<int64_t> staticSizes;
  SmallVector<Value> sizes;
  for (size_t i = 0; i < rank; ++i) {
    const bool forceDynamic = (i == dimIdx && readShape[i] == -1);
    if (!forceDynamic) {
      if (auto constOp = srcShapeVals[i].getDefiningOp<arith::ConstantIndexOp>()) {
        staticSizes.push_back(constOp.value());
        continue;
      }
    }
    staticSizes.push_back(ShapedType::kDynamic);
    sizes.push_back(toIndexValue(srcShapeVals[i]));
  }

  // Step 1b: row-major strides; a stride is dynamic as soon as any inner
  // dim is dynamic, in which case it is computed at runtime.
  SmallVector<Value> offsets, strides;
  SmallVector<int64_t> staticOffsets, staticStrides;
  staticOffsets.push_back(0);
  for (size_t i = 0; i < rank; ++i) {
    int64_t staticStride = 1;
    bool isDynamic = false;
    for (size_t j = i + 1; j < rank; ++j) {
      if (staticSizes[j] == ShapedType::kDynamic) {
        isDynamic = true;
        break;
      }
      staticStride *= staticSizes[j];
    }
    if (isDynamic) {
      staticStride = ShapedType::kDynamic;
      Value strideVal = rewriter.create<arith::ConstantIndexOp>(loc, 1);
      for (size_t j = i + 1; j < rank; ++j) {
        if (staticSizes[j] != ShapedType::kDynamic) {
          strideVal = rewriter.create<arith::MulIOp>(
              loc, strideVal,
              rewriter.create<arith::ConstantIndexOp>(loc, staticSizes[j]));
        } else {
          strideVal = rewriter.create<arith::MulIOp>(
              loc, strideVal, toIndexValue(srcShapeVals[j]));
        }
      }
      strides.push_back(strideVal);
    }
    staticStrides.push_back(staticStride);
  }
  auto fullSrcMemRefType = MemRefType::get(staticSizes, elemType);
  auto srcMemRef = rewriter.create<memref::ReinterpretCastOp>(
      loc, fullSrcMemRefType, src, offsets, sizes, strides,
      staticOffsets, staticSizes, staticStrides);

  // Step 2: result buffer and the loop over the selected indices.
  auto resultMemRefType = MemRefType::get(resultShape, elemType);
  auto outputBuffer = rewriter.create<memref::AllocOp>(loc, resultMemRefType);
  auto indicesTensorType = cast<RankedTensorType>(indexTensor.getType());
  int64_t numIndices = indicesTensorType.getShape()[0];
  auto zeroIdx = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto numIndicesVal = rewriter.create<arith::ConstantIndexOp>(loc, numIndices);
  auto stepOne = rewriter.create<arith::ConstantIndexOp>(loc, 1);
  auto forOp = rewriter.create<scf::ForOp>(loc, zeroIdx, numIndicesVal, stepOne);
  forOp->setAttr("hivm.parallel_loop", rewriter.getUnitAttr());

  // Step 3: loop body, emitted before the scf.yield terminator. The guard
  // restores the insertion point (just after the loop) at scope exit.
  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(forOp.getBody()->getTerminator());
    Value iv = forOp.getInductionVar();
    Value selectedIdx = rewriter.create<tensor::ExtractOp>(loc, indexTensor, ValueRange{iv});
    Value selectedIdxAsIndex = toIndexValue(selectedIdx);

    SmallVector<OpFoldResult> srcSubviewOffsets, srcSubviewSizes, srcSubviewStrides;
    for (size_t i = 0; i < srcOffsetVals.size(); ++i) {
      if (i == dimIdx) {
        srcSubviewOffsets.push_back(selectedIdxAsIndex);
        srcSubviewSizes.push_back(rewriter.getIndexAttr(1));
      } else {
        srcSubviewOffsets.push_back(toIndexValue(srcOffsetVals[i]));
        srcSubviewSizes.push_back(rewriter.getIndexAttr(readShape[i]));
      }
      srcSubviewStrides.push_back(rewriter.getIndexAttr(1));
    }
    auto srcSubview = rewriter.create<memref::SubViewOp>(
        loc, srcMemRef, srcSubviewOffsets, srcSubviewSizes, srcSubviewStrides);

    SmallVector<OpFoldResult> dstSubviewOffsets, dstSubviewSizes, dstSubviewStrides;
    for (size_t i = 0; i < resultShape.size(); ++i) {
      if (i == dimIdx) {
        dstSubviewOffsets.push_back(iv);
        dstSubviewSizes.push_back(rewriter.getIndexAttr(1));
      } else {
        dstSubviewOffsets.push_back(rewriter.getIndexAttr(0));
        dstSubviewSizes.push_back(rewriter.getIndexAttr(readShape[i]));
      }
      dstSubviewStrides.push_back(rewriter.getIndexAttr(1));
    }
    auto dstSubview = rewriter.create<memref::SubViewOp>(
        loc, outputBuffer, dstSubviewOffsets, dstSubviewSizes, dstSubviewStrides);

    if (dimIdx == rank - 1) {
      auto copyOp = rewriter.create<memref::CopyOp>(loc, srcSubview, dstSubview);
      copyOp->setAttr(ConverterUtils::discreteAttrName,
                      rewriter.getUnitAttr());
    } else {
      auto dstMarkOp = rewriter.create<annotation::MarkOp>(loc, dstSubview);
      dstMarkOp->setAttr("hfusion.stride_align_dims",
                         rewriter.getDenseI32ArrayAttr({static_cast<int32_t>(dim)}));
      dstMarkOp->setAttr("hfusion.stride_align_value_in_byte",
                         rewriter.getDenseI32ArrayAttr({32}));
      rewriter.create<memref::CopyOp>(loc, srcSubview, dstSubview);
    }
  }

  // Step 4: expose the filled buffer as the op's tensor result.
  auto resultTensor = rewriter.create<bufferization::ToTensorOp>(
      loc, resultTensorType, outputBuffer, true, true);
  resultTensor->setAttr("index_select_simd", rewriter.getUnitAttr());
  rewriter.replaceOp(op, resultTensor);
  return success();
}
}