* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "TritonToStructured/MemOpConverter.h"
#include <cassert>
#include <numeric>
#include <type_traits>
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Debug.h"
#include "TritonToStructured/TritonToStructuredPass.h"
#include "TritonToStructured/PtrAnalysis.h"
#include "TritonToStructured/MaskAnalysis.h"
#include "TritonToStructured/CannonicalizerConverter.h"
#include "Utils/InterleaveOptimization.h"
#include "Utils/Utils.h"
#include "bishengir/Dialect/Annotation/IR/Annotation.h"
#include "bishengir/Dialect/HIVM/Utils/Utils.h"
#define DEBUG_TYPE "triton-mem-op-converter"
namespace MemOpConverter {
using namespace mlir;
using namespace triton;
using namespace TritonToStructured;
/// Rewrites a `triton::LoadOp` into a load over a re-derived pointer, mask and
/// `other` value, then restores the original result through the transformer's
/// broadcast, permute, reshape and select steps (in that order).
///
/// Returns failure, leaving `op` untouched, when
///   * the pointer state is not flagged `shouldLinearize`,
///   * no new pointer was produced, or
///   * the op has a mask that could not be rebuilt and
///     `enableMaskFallbackConversion` is off.
///
/// NOTE(review): the rebuild helpers receive the rewriter and run before the
/// failure checks; whether they leave ops behind on the failure paths is not
/// visible from this function — confirm against the pattern driver's rules.
LogicalResult
LoadConverter::matchAndRewrite(triton::LoadOp op,
                               PatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  auto srcMask = op.getMask();
  auto srcOther = op.getOther();

  MemOpTransformer transformer(MemOpTransformer::MemType::load,
                               optimizeDynamicOffset);

  // Keep this order: the mask rebuild reads the pointer state filled in by the
  // pointer rebuild, and the `other` rebuild reads the mask state.
  auto rebuiltPtr = transformer.createNewPtr(op.getPtr(), loc, rewriter);
  auto rebuiltMask = transformer.createNewMask(srcMask, loc, rewriter);
  auto rebuiltOther = transformer.createNewOther(srcOther, loc, rewriter);

  if (!transformer.ptrState.shouldLinearize)
    return failure();

  if (!rebuiltPtr) {
    LLVM_DEBUG({
      emitWarning(loc) << "PtrAnalysis: failed to analyze load pointer.";
    });
    return failure();
  }

  // The op carried a mask, but no replacement mask could be derived.
  const bool maskNotRebuilt = srcMask && !rebuiltMask;
  if (maskNotRebuilt && !enableMaskFallbackConversion) {
    LLVM_DEBUG({
      emitWarning(loc) << "MaskAnalysis: failed to analyze load mask.";
    });
    return failure();
  }

  auto newLoad = rewriter.create<triton::LoadOp>(
      loc, rebuiltPtr, rebuiltMask, rebuiltOther, op.getCache(), op.getEvict(),
      op.getIsVolatile());

  // Thread the loaded value through each materialization step in turn.
  Value result = newLoad.getResult();
  result = transformer.materializeImplicitBroadcast(result, loc, rewriter);
  result = transformer.materializeImplicitPermute(result, loc, rewriter);
  result = transformer.materializeImplicitReshape(result, loc, rewriter);
  result = transformer.materializeImplicitSelect(result, srcMask, srcOther, loc,
                                                 rewriter);

  rewriter.replaceOp(op, result);
  return success();
}
/// Rewrites a `triton::StoreOp` into a store over a re-derived pointer and
/// mask. The stored value is first passed through the transformer's select,
/// reshape and permute steps (in that order).
///
/// When the op has a mask but no replacement mask could be derived (only
/// reachable with `enableMaskFallbackConversion` on), the select/store sequence
/// is bracketed by `hivm::SyncBlockLockOp` / `hivm::SyncBlockUnlockOp` on a
/// freshly created lock variable. The lock variable is created only in that
/// case, so stores that need no lock do not get an unused lock-creation op.
///
/// Returns failure, leaving `op` untouched, when
///   * the pointer state is not flagged `shouldLinearize`,
///   * no new pointer was produced, or
///   * the mask could not be rebuilt and `enableMaskFallbackConversion` is off.
LogicalResult
StoreConverter::matchAndRewrite(triton::StoreOp op,
                                PatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  auto oldPtr = op.getPtr();
  auto oldMask = op.getMask();
  auto oldValue = op.getValue();

  MemOpTransformer tf(MemOpTransformer::MemType::store, optimizeDynamicOffset);

  // Pointer first: the mask rebuild reads the pointer state it fills in.
  auto newPtr = tf.createNewPtr(oldPtr, loc, rewriter);
  auto newMask = tf.createNewMask(oldMask, loc, rewriter);

  if (!tf.ptrState.shouldLinearize) {
    return failure();
  }
  if (!newPtr) {
    LLVM_DEBUG({
      InFlightDiagnostic diag =
          emitWarning(loc) << "PtrAnalysis: failed to analyze store pointer.";
    });
    return failure();
  }

  // True when the op is masked but the mask could not be re-expressed.
  const bool maskFallback = oldMask && !newMask;
  if (!enableMaskFallbackConversion && maskFallback) {
    LLVM_DEBUG({
      InFlightDiagnostic diag =
          emitWarning(loc) << "MaskAnalysis: failed to analyze store mask.";
    });
    return failure();
  }

  // Only the fallback path uses the lock, so only create it there.
  hivm::CreateSyncBlockLockOp lockVar;
  if (maskFallback) {
    lockVar = createSyncBlockLockVar(rewriter, loc);
    rewriter.create<hivm::SyncBlockLockOp>(loc, lockVar);
  }

  // In store mode the select step receives the destination pointer as `other`.
  auto selectResult =
      tf.materializeImplicitSelect(oldValue, oldMask, oldPtr, loc, rewriter);
  auto reshapeResult =
      tf.materializeImplicitReshape(selectResult, loc, rewriter);
  auto permuteResult =
      tf.materializeImplicitPermute(reshapeResult, loc, rewriter);

  rewriter.create<triton::StoreOp>(loc, newPtr, permuteResult, newMask,
                                   op.getBoundaryCheck(), op.getCache(),
                                   op.getEvict());

  if (maskFallback) {
    rewriter.create<hivm::SyncBlockUnlockOp>(loc, lockVar);
  }

  rewriter.eraseOp(op);
  return success();
}
/// Expands `srcTensor` along every dimension of `ptrState.stateInfo` whose
/// stride is zero.
///
/// The target shape is taken from the static `shape` of each entry. If any
/// entry has a non-static shape, or no entry has a zero stride, `srcTensor` is
/// returned unchanged. An int/float scalar source is expanded with
/// `triton::SplatOp`; any other source goes through `linalg::BroadcastOp` into
/// a fresh `tensor::EmptyOp` of the target shape.
Value MemOpTransformer::materializeImplicitBroadcast(Value srcTensor, const Location loc,
                                                     PatternRewriter &rewriter) {
  SmallVector<int64_t> zeroStrideDims;
  SmallVector<int64_t> fullShape;

  for (size_t dim = 0; dim < ptrState.stateInfo.size(); ++dim) {
    auto &dimState = ptrState.stateInfo[dim];
    if (isZero(dimState.stride)) {
      zeroStrideDims.emplace_back(dim);
    }

    auto extent = getIntAttr(dimState.shape);
    if (!extent.has_value()) {
      LLVM_DEBUG({
        emitWarning(loc) << "PtrAnalysis: dynamic shape is not supported in broadcast\n";
      });
      return srcTensor;
    }
    fullShape.emplace_back(extent.value());
  }

  // No zero-stride dimension: nothing to expand.
  if (zeroStrideDims.empty())
    return srcTensor;

  Type sourceType = srcTensor.getType();
  if (sourceType.isIntOrFloat()) {
    // Scalar source: splat it to the full shape.
    auto splatType = RankedTensorType::get(fullShape, sourceType);
    return rewriter.create<triton::SplatOp>(loc, splatType, srcTensor)
        .getResult();
  }

  auto destination = rewriter.create<tensor::EmptyOp>(
      loc, fullShape, cast<ShapedType>(srcTensor.getType()).getElementType());
  auto expanded = rewriter.create<linalg::BroadcastOp>(
      loc, srcTensor, destination, zeroStrideDims);
  return expanded->getResult(0);
}
/// Reshapes `srcTensor` with `tensor::ReshapeOp` when `ptrState.sizes` and
/// `ptrState.stateInfo` have different lengths.
///
/// The target shape comes from `ptrState.sizes` for loads and from the `shape`
/// of each `ptrState.stateInfo` entry otherwise. If the two lengths match, or
/// any required extent is not static, `srcTensor` is returned unchanged.
Value MemOpTransformer::materializeImplicitReshape(Value srcTensor, const Location loc,
                                                   PatternRewriter &rewriter) {
  if (ptrState.sizes.size() == ptrState.stateInfo.size())
    return srcTensor;

  SmallVector<int64_t> targetShape;

  // Appends `extent` to `targetShape` if it is static; otherwise reports the
  // problem (debug builds only) and returns false.
  auto appendStaticExtent = [&](auto extent) {
    auto staticExtent = getIntAttr(extent);
    if (!staticExtent.has_value()) {
      LLVM_DEBUG({
        emitWarning(loc) << "PtrAnalysis: dynamic shape is not supported in reshape\n";
      });
      return false;
    }
    targetShape.emplace_back(staticExtent.value());
    return true;
  };

  if (currentType == MemType::load) {
    for (const auto &size : ptrState.sizes) {
      if (!appendStaticExtent(size))
        return srcTensor;
    }
  } else {
    for (const auto &dimState : ptrState.stateInfo) {
      if (!appendStaticExtent(dimState.shape))
        return srcTensor;
    }
  }

  // The shape operand is a 1-D i64 constant tensor holding the target extents.
  auto shapeOperandType = RankedTensorType::get(
      {static_cast<int64_t>(targetShape.size())}, rewriter.getI64Type());
  auto shapeOperandAttr =
      DenseIntElementsAttr::get(shapeOperandType, targetShape);
  auto resultType = RankedTensorType::get(
      targetShape, cast<ShapedType>(srcTensor.getType()).getElementType());

  auto shapeOperand = rewriter.create<arith::ConstantOp>(loc, shapeOperandAttr);
  return rewriter
      .create<tensor::ReshapeOp>(loc, resultType, srcTensor, shapeOperand)
      .getResult();
}
/// Applies `mask` to `srcTensor` with an explicit `arith::SelectOp` when the
/// mask state holds no new mask.
///
/// Returns `srcTensor` unchanged if `mask` is null, if `maskState.newMask` is
/// set, or if the mask and value shapes differ. In store mode `other` is used
/// as the address to load the fallback value from; otherwise a null `other` is
/// replaced by a `tensor::EmptyOp` of the value's shape and element type.
Value MemOpTransformer::materializeImplicitSelect(Value srcTensor, Value mask,
                                                  Value other, const Location loc,
                                                  PatternRewriter &rewriter) {
  if (!mask)
    return srcTensor;
  if (maskState.newMask)
    return srcTensor;

  auto valueType = cast<RankedTensorType>(srcTensor.getType());
  auto maskShape = cast<ShapedType>(mask.getType()).getShape();
  if (maskShape != valueType.getShape()) {
    LLVM_DEBUG({
      emitWarning(loc) << "MaskAnalysis: mask shape is not same as Value";
    });
    return srcTensor;
  }

  if (currentType == MemType::store) {
    // Store mode: `other` is a pointer; read the current contents through it.
    auto currentContents = rewriter.create<triton::LoadOp>(loc, other, nullptr, nullptr,
                                                           ArrayRef<int32_t>(), nullptr);
    other = currentContents.getResult();
  }

  if (!other) {
    // No fallback value supplied: use an empty tensor of matching type.
    auto placeholder = rewriter.create<tensor::EmptyOp>(
        loc, valueType.getShape(), valueType.getElementType());
    other = placeholder.getResult();
  }

  auto selected = rewriter.create<arith::SelectOp>(loc, mask, srcTensor, other);
  return selected->getResult(0);
}
/// Transposes `srcTensor` with `triton::TransOp` according to
/// `ptrState.permuteIds`.
///
/// For loads the inverse of `permuteIds` is used as the transpose order; for
/// other accesses `permuteIds` is used directly. `srcTensor` is returned
/// unchanged if it is not a ranked tensor, if the pointer state is not
/// permuted, or if its rank differs from the number of permute ids.
Value MemOpTransformer::materializeImplicitPermute(Value srcTensor, const Location loc,
                                                   PatternRewriter &rewriter) {
  auto sourceType = dyn_cast<RankedTensorType>(srcTensor.getType());
  if (!sourceType)
    return srcTensor;
  if (!ptrState.isPermuted)
    return srcTensor;

  auto sourceShape = sourceType.getShape();
  const size_t rank = ptrState.permuteIds.size();
  const bool invertOrder = (currentType == MemType::load);

  SmallVector<int32_t> order(rank);
  for (size_t pos = 0; pos < rank; ++pos) {
    if (invertOrder) {
      order[ptrState.permuteIds[pos]] = pos;
    } else {
      order[pos] = ptrState.permuteIds[pos];
    }
  }

  if (sourceShape.size() != rank) {
    LLVM_DEBUG({
      emitWarning(loc) << "PtrAnalysis: incompatible shape for permute";
    });
    return srcTensor;
  }

  // Result dimension i takes its extent from source dimension order[i].
  SmallVector<int64_t> permutedShape;
  permutedShape.reserve(rank);
  for (int32_t sourceDim : order) {
    permutedShape.push_back(sourceShape[sourceDim]);
  }

  auto resultType =
      RankedTensorType::get(permutedShape, sourceType.getElementType());
  return rewriter.create<triton::TransOp>(loc, resultType, srcTensor, order)
      .getResult();
}
/// Runs pointer analysis on `oldPtr` and builds the replacement pointer from
/// the resulting `ptrState`.
///
/// If the analysis fails, `ptrState.shouldLinearize` is cleared and `oldPtr` is
/// returned. Otherwise:
///   * walking dimensions from last to first, any dimension with shape one and
///     stride zero has its stride set to the running maximum stride (which
///     starts at 1), so it no longer counts as a zero-stride dimension;
///   * `ptrState.shouldLinearize` is set if a zero-stride dimension remains;
///   * the original permute ids are generated and the result of
///     `ptrState.createAddPtrOp` is returned.
Value MemOpTransformer::createNewPtr(Value oldPtr,
                                     const Location loc, PatternRewriter &rewriter) {
  TritonToStructured::PtrAnalysis ptrAnalysis(optimizeDynamicOffset);
  LLVM_DEBUG({
    llvm::dbgs() << "----------------------------------------------\n";
    llvm::dbgs() << "PtrAnalysis: analyzing load/store's ptr.\n";
  });

  if (ptrAnalysis.visitOperand(oldPtr, ptrState, loc, rewriter).failed()) {
    ptrState.shouldLinearize = false;
    LLVM_DEBUG({
      InFlightDiagnostic diag =
          emitWarning(loc) << "PtrAnalysis: failed to analyze load/store ptr.";
    });
    return oldPtr;
  }

  // Walk from the last dimension to the first, tracking the largest stride.
  OpFoldResult maxStride = rewriter.getIndexAttr(1);
  for (auto &info : llvm::reverse(ptrState.stateInfo)) {
    if (TritonToStructured::isOne(info.shape) && isZero(info.stride)) {
      info.stride = maxStride;
    }
    maxStride = maxOpFoldResult(maxStride, info.stride, loc, rewriter);
  }

  // A remaining zero stride forces linearization; the flag is never cleared
  // here, so a value set by the analysis is preserved.
  const bool hasZeroStride = llvm::any_of(
      ptrState.stateInfo, [](auto &info) { return isZero(info.stride); });
  if (hasZeroStride) {
    ptrState.shouldLinearize = true;
  }

  ptrState.generateOriginPermuteIds();
  return ptrState.createAddPtrOp(rewriter, loc);
}
/// Analyzes `oldMask` and rebuilds it so its dimensions line up with
/// `ptrState.stateInfo`.
///
/// Returns a null value when `oldMask` is null, when mask analysis fails, when
/// the mask and pointer dimensions cannot be matched, or when the pointer's
/// permutation cannot be applied to the mask. Must be called after the pointer
/// state has been populated, since the matching loop reads
/// `ptrState.stateInfo`.
///
/// Matching walks pointer and mask dimensions in parallel. Each pointer
/// dimension yields one mask entry whose shape is the smaller of the two
/// shapes. Entries for zero-stride pointer dimensions are dropped. The mask
/// iterator only advances once the emitted shape equals the mask dimension's
/// shape.
Value MemOpTransformer::createNewMask(Value oldMask,
                                      const Location loc, PatternRewriter &rewriter) {
  if (!oldMask) return nullptr;

  LLVM_DEBUG({
    llvm::dbgs() << "----------------------------------------------\n";
    llvm::dbgs() << "MaskAnalysis: analyzing load/store mask.\n";
  });

  // `oldMask` is known to be non-null here, so only the analysis can fail.
  if (maskState.analysisMask(oldMask).failed()) {
    LLVM_DEBUG({
      llvm::dbgs() << "----------------------------------------------\n";
      llvm::dbgs() << "MaskAnalysis: no mask or failed to analyze mask.\n";
      llvm::dbgs() << "oldMask:" << oldMask << "\n";
      maskState.dump();
      llvm::dbgs() << "----------------------------------------------\n";
      emitWarning(loc) << "MaskAnalysis: failed to analyze load/store mask.";
    });
    return nullptr;
  }

  SmallVector<TritonToStructured::dimInfo> newMaskInfo;
  auto itPtr = ptrState.stateInfo.begin();
  auto itMask = maskState.stateInfo.begin();
  while (itPtr != ptrState.stateInfo.end() &&
         itMask != maskState.stateInfo.end()) {
    if (!isMultiple(itMask->shape, itPtr->shape)) {
      LLVM_DEBUG({
        llvm::dbgs() << "----------------------------------------------\n";
        ptrState.dump();
        llvm::dbgs() << "oldMask:" << oldMask << "\n";
        maskState.dump();
        llvm::dbgs() << "----------------------------------------------\n";
        emitWarning(loc) << "MaskAnalysis: incompatible shapes between ptr and mask.";
      });
      return nullptr;
    }

    auto newShape = minOpFoldResult(itMask->shape, itPtr->shape, loc, rewriter);
    // Shrinking a mask dimension is only accepted for broadcast dimensions.
    if (isLess(newShape, itMask->shape) && !itMask->hasBroadCast) {
      LLVM_DEBUG({
        emitWarning(loc) << "MaskAnalysis: the mask shape is incompatible with ptr shape.";
      });
      return nullptr;
    }

    TritonToStructured::dimInfo newInfo(
        itMask->offset, newShape, itMask->dimIndex,
        itMask->hasBroadCast, itMask->currentType,
        itMask->rhs);
    // Zero-stride pointer dimensions get no mask entry.
    if (!isZero(itPtr->stride)) {
      newMaskInfo.emplace_back(newInfo);
    }

    ++itPtr;
    if (isEqual(itMask->shape, newShape)) {
      ++itMask;
    }
  }

  // Both sequences must be consumed completely.
  if (itPtr != ptrState.stateInfo.end() ||
      itMask != maskState.stateInfo.end()) {
    LLVM_DEBUG({
      llvm::dbgs() << "----------------------------------------------\n";
      llvm::dbgs() << "MaskAnalysis: incompatible number of dimensions between ptr and mask.\n";
      ptrState.dump();
      llvm::dbgs() << "oldMask:" << oldMask << "\n";
      maskState.dump();
      llvm::dbgs() << "----------------------------------------------\n";
      emitWarning(loc) << "MaskAnalysis: incompatible number of dimensions between ptr and mask.";
    });
    return nullptr;
  }

  maskState.stateInfo = newMaskInfo;
  if (ptrState.isPermuted && !applyPermuteOnMask()) {
    LLVM_DEBUG({
      llvm::dbgs() << "----------------------------------------------\n";
      llvm::dbgs() << "MaskAnalysis: failed to apply permute on mask.\n";
      ptrState.dump();
      llvm::dbgs() << "oldMask:" << oldMask << "\n";
      maskState.dump();
      llvm::dbgs() << "----------------------------------------------\n";
      emitWarning(loc) << "MaskAnalysis: failed to apply permute on mask.";
    });
    return nullptr;
  }

  LLVM_DEBUG({
    llvm::dbgs() << "After matching MaskState: \n";
    for (auto &info : newMaskInfo) {
      info.dump();
    }
    llvm::dbgs() << "----------------------------------------------\n";
  });

  return maskState.createNewMask(loc, rewriter);
}
/// Reshapes `oldOther` with `tensor::ReshapeOp` to the static shape described
/// by `maskState.stateInfo`, using the pointee type of `ptrState.source` as the
/// element type.
///
/// Returns a null value if `oldOther` is null, if `maskState.newMask` is not
/// set, or if `ptrState.source` is not a `triton::PointerType`. Returns
/// `oldOther` unchanged if any mask dimension has a non-static shape.
Value MemOpTransformer::createNewOther(Value oldOther,
                                       const Location loc, PatternRewriter &rewriter) {
  if (!oldOther)
    return nullptr;
  if (!maskState.newMask)
    return nullptr;

  auto sourcePtrType = dyn_cast<triton::PointerType>(ptrState.source.getType());
  if (!sourcePtrType) {
    LLVM_DEBUG({
      emitWarning(loc) << "PtrAnalysis: source of ptrState is not a pointer type.";
    });
    return nullptr;
  }
  Type pointeeType = sourcePtrType.getPointeeType();

  // Collect the static extent of every mask dimension.
  SmallVector<int64_t> maskShape;
  for (const auto &dimState : maskState.stateInfo) {
    auto extent = getIntAttr(dimState.shape);
    if (!extent.has_value()) {
      LLVM_DEBUG({
        emitWarning(loc) << "MaskAnalysis: dynamic shape is not supported in reshape\n";
      });
      return oldOther;
    }
    maskShape.emplace_back(extent.value());
  }

  // The shape operand is a 1-D i64 constant tensor holding the target extents.
  auto shapeOperandType = RankedTensorType::get(
      {static_cast<int64_t>(maskShape.size())}, rewriter.getI64Type());
  auto shapeOperandAttr = DenseIntElementsAttr::get(shapeOperandType, maskShape);
  auto resultType = RankedTensorType::get(maskShape, pointeeType);

  auto shapeOperand = rewriter.create<arith::ConstantOp>(loc, shapeOperandAttr);
  return rewriter
      .create<tensor::ReshapeOp>(loc, resultType, oldOther, shapeOperand)
      .getResult();
}
bool MemOpTransformer::applyPermuteOnMask() {
if (!ptrState.isPermuted || maskState.isEmpty()) {
return true;
}
if (ptrState.permuteIds.size() != maskState.stateInfo.size()) {
return false;
}
SmallVector<TritonToStructured::dimInfo> newMaskInfo;
for (auto id : ptrState.permuteIds) {
newMaskInfo.push_back(maskState.stateInfo[id]);
}
maskState.stateInfo = newMaskInfo;
return true;
}
/// Creates a `hivm::CreateSyncBlockLockOp` at `loc` whose result type is a
/// one-element i64 memref, passing a null `Value` as its operand, and returns
/// the created op.
hivm::CreateSyncBlockLockOp createSyncBlockLockVar(OpBuilder &builder,
                                                   Location loc) {
  // Shape of the lock variable: a single element.
  const int64_t lockShape[] = {1};
  Type lockType = MemRefType::get(lockShape, builder.getI64Type());
  return builder.create<hivm::CreateSyncBlockLockOp>(loc, lockType, Value());
}
}