* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "ascend/include/TritonToLinalg/BlockPtrAnalysis.h"
#include "ascend/include/TritonToLinalg/LoadStoreConverter.h"
#include "ascend/include/TritonToLinalg/MaskAnalysis.h"
#include "ascend/include/TritonToLinalg/TritonToLinalgPass.h"
#include "ascend/include/Utils/InterleaveOptimization.h"
#include "ascend/include/Utils/Utils.h"
#include "bishengir/Dialect/Annotation/IR/Annotation.h"
#include "bishengir/Dialect/HFusion/IR/HFusion.h"
#include "bishengir/Dialect/HIVM/IR/HIVM.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "Utils/Utils.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/Debug.h"
#include <cassert>
#include <numeric>
#include <type_traits>
#define DEBUG_TYPE "triton-load-store-converter"
namespace LoadStoreConverter {
using namespace mlir;
using namespace triton;
// Name of the unit attribute set on annotation::MarkOp ops emitted by the load
// lowering in this file to flag a loaded buffer/tensor that may be implicitly
// transposed along its last axis (see how `mayImplicitTransposeWithLastAxis`
// is computed in LoadConverter::matchAndRewrite).
const std::string MayImplicitTransposeWithLastAxisTAG = "MayImplicitTransposeWithLastAxis";
/// Lowers a triton::AddPtrOp by handing it to BlockDataParser::rewriteAddPtr,
/// which performs the rewrite through `rewriter`.
///
/// A fresh, empty map of known block data is supplied on every call, so no
/// state is shared between invocations. The pattern always reports success.
LogicalResult
AddPtrConverter::matchAndRewrite(triton::AddPtrOp op, OpAdaptor adaptor,
                                 ConversionPatternRewriter &rewriter) const {
  // Per-invocation cache handed to the parser.
  llvm::SmallDenseMap<Value, BlockData> knownBlockData;
  BlockDataParser::rewriteAddPtr(op, adaptor, rewriter, knownBlockData);
  return success();
}
/// Exposes the buffer `localMem` as a tensor of type `tensorType` and uses it
/// to replace `op`.
///
/// Steps:
///  * a bufferization::ToTensorOp is built (both boolean builder flags set);
///  * the "was_bool_to_int8" marker is forwarded from `op` to it;
///  * if `mayImplicitTransposeWithLastAxis` is set, the tensor is tagged with
///    an annotation::MarkOp carrying MayImplicitTransposeWithLastAxisTAG;
///  * `op` is replaced with the tensor.
LogicalResult LoadConverter::toTensorAndReplace(
    triton::LoadOp &op, RankedTensorType &tensorType, Value localMem,
    bool mayImplicitTransposeWithLastAxis, const Location &loc,
    ConversionPatternRewriter &rewriter) const {
  auto toTensor = rewriter.create<bufferization::ToTensorOp>(
      loc, tensorType, localMem, true, true);
  Value asTensor = toTensor;
  propagateWasBoolToInt8Attr(op.getOperation(), toTensor.getOperation(),
                             rewriter);

  if (mayImplicitTransposeWithLastAxis) {
    auto mark = rewriter.create<annotation::MarkOp>(loc, asTensor);
    mark->setAttr(MayImplicitTransposeWithLastAxisTAG, rewriter.getUnitAttr());
  }

  rewriter.replaceOp(op, asTensor);
  return success();
}
/// Decides whether `op` was left half-rewritten by AddPtrConverter and must be
/// finished by continueModifyFromAddPtrConverter.
///
/// Succeeds only when all of the following hold:
///  * the load's direct parent is an scf::ForOp;
///  * the load carries the "IndirectLoad" attribute;
///  * its pointer lives in a different block from the load itself.
LogicalResult
LoadConverter::checkModifiedByAddPtrConverter(triton::LoadOp &op) const {
  if (!isa<scf::ForOp>(op->getParentOp()))
    return failure();
  if (!op->hasAttr("IndirectLoad"))
    return failure();
  // Value::getParentBlock() covers both op results and block arguments. A
  // block-argument pointer has no defining op, so going through
  // getDefiningOp()->getBlock() would dereference null.
  Block *ptrBlock = op.getPtr().getParentBlock();
  Block *opBlock = op->getBlock();
  if (ptrBlock == opBlock)
    return failure();
  return success();
}
/// Forwards the "was_bool_to_int8" marker from `srcLoadOp` to `dstOp`.
///
/// When the source carries that attribute, the destination receives it as a
/// BoolAttr set to true. Nothing happens if either operation pointer is null.
void LoadConverter::propagateWasBoolToInt8Attr(Operation *srcLoadOp,
                                               Operation *dstOp,
                                               PatternRewriter &rewriter) const {
  static constexpr char kWasBoolToInt8Tag[] = "was_bool_to_int8";
  const bool bothPresent = (srcLoadOp != nullptr) && (dstOp != nullptr);
  if (bothPresent && srcLoadOp->hasAttr(kWasBoolToInt8Tag))
    dstOp->setAttr(kWasBoolToInt8Tag, rewriter.getBoolAttr(true));
}
/// Finishes a load that AddPtrConverter left inside a loop nest.
///
/// The load is rewritten as follows:
///  * element 0 of the adapted pointer (a memref.reinterpret_cast) is loaded;
///  * the value is inserted into the closest enclosing scf.for's first
///    iter_arg, at the indices of the tensor.extract that opens that loop's
///    body;
///  * the loop's terminator is replaced by a yield of the updated tensor;
///  * `op` is replaced with the results of the scf.for found by walking
///    rank(result) loops outward from the load.
///
/// Every structural precondition is checked before the IR is touched, so an
/// unexpected shape yields a match failure instead of a null dereference.
LogicalResult LoadConverter::continueModifyFromAddPtrConverter(
    triton::LoadOp &op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto loc = op.getLoc();
  auto forOp = op->getParentOfType<scf::ForOp>();
  if (!forOp)
    return rewriter.notifyMatchFailure(op, "expected an enclosing scf.for");
  auto extractOp = dyn_cast<tensor::ExtractOp>(&forOp.getBody()->front());
  if (!extractOp)
    return rewriter.notifyMatchFailure(
        op, "expected the loop body to start with tensor.extract");
  if (forOp.getRegionIterArgs().empty())
    return rewriter.notifyMatchFailure(op, "expected a loop-carried tensor");
  auto castOp = adaptor.getPtr().getDefiningOp<memref::ReinterpretCastOp>();
  if (!castOp)
    return rewriter.notifyMatchFailure(
        op, "expected the pointer to come from memref.reinterpret_cast");
  auto resultType = dyn_cast<ShapedType>(op.getResult().getType());
  if (!resultType)
    return rewriter.notifyMatchFailure(op, "expected a shaped load result");

  // Walk one scf.for outward per result dimension; the loop reached last
  // produces the value that replaces the load.
  Operation *rootForOp = op;
  for (size_t rank = resultType.getShape().size(); rank != 0; --rank) {
    rootForOp = rootForOp->getParentOfType<scf::ForOp>();
    if (!rootForOp)
      return rewriter.notifyMatchFailure(
          op, "loop nest is shallower than the result rank");
  }

  auto ivs = extractOp.getIndices();
  auto iterArg = forOp.getRegionIterArg(0);
  rewriter.setInsertionPointAfter(op);
  Value castVal = castOp;
  Value idxZero =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
  Value loadVal =
      rewriter.create<memref::LoadOp>(loc, castVal, ValueRange{idxZero});
  propagateWasBoolToInt8Attr(op.getOperation(), loadVal.getDefiningOp(), rewriter);
  Value insertedVal =
      rewriter.create<tensor::InsertOp>(loc, loadVal, iterArg, ValueRange{ivs});
  auto oldYieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
  auto yieldOp = rewriter.create<scf::YieldOp>(loc, ValueRange{insertedVal});
  rewriter.replaceOp(oldYieldOp, yieldOp);
  rewriter.replaceOp(op, rootForOp);
  LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(rootForOp) << "\n"; });
  return success();
}
/// Pre-fills `localMem` with `other` when the masked region does not cover
/// the whole buffer.
///
/// For each dimension `i` this emits `arith.cmpi slt, maskDim[i], dimSize(i)`
/// and ORs the results together. The combined flag guards a linalg.fill inside
/// an scf.if tagged "hivm.unlikely_condition".
///
/// \param other    Scalar fill value.
/// \param localMem Memref to fill; must have a static shape.
/// \param maskDim  Per-dimension extent of the masked region, either static
///                 (an integer attribute) or a dynamic value. Its size must
///                 equal the memref rank.
void LoadConverter::fillTensorWithOtherForMaskScenario(
    Value other, Value localMem, ArrayRef<OpFoldResult> maskDim,
    ConversionPatternRewriter &rewriter) const {
  auto loc = localMem.getLoc();
  auto originalType = cast<MemRefType>(localMem.getType());
  assert(originalType.hasStaticShape() && "only support static shape");
  assert(static_cast<size_t>(originalType.getRank()) == maskDim.size() &&
         "shape and mask must have same rank");
  Value fillFlag =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getBoolAttr(false));
  for (int64_t i = 0, rank = originalType.getRank(); i < rank; ++i) {
    Value shapeVal = rewriter.create<arith::ConstantOp>(
        loc, rewriter.getIndexAttr(originalType.getDimSize(i)));
    // Static extents are materialised as index constants, so both cmpi
    // operands are index-typed whatever integer type the attribute carries.
    Value maskDimVal = getValueOrCreateConstantIndexOp(rewriter, loc, maskDim[i]);
    Value isPartial = rewriter.create<arith::CmpIOp>(
        loc, arith::CmpIPredicate::slt, maskDimVal, shapeVal);
    fillFlag = rewriter.create<arith::OrIOp>(loc, fillFlag, isPartial);
  }
  auto ifOp = rewriter.create<scf::IfOp>(loc, fillFlag);
  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
    rewriter.create<linalg::FillOp>(loc, ValueRange{other},
                                    ValueRange{localMem});
  }
  ifOp->setAttr(rewriter.getStringAttr("hivm.unlikely_condition"),
                UnitAttr::get(rewriter.getContext()));
}
// Constructs the triton::LoadOp conversion pattern in `context`; no explicit
// pattern benefit is passed to the base class.
LoadConverter::LoadConverter(MLIRContext *context)
: OpConversionPattern<triton::LoadOp>(context) {}
/// Lowers a triton::LoadOp whose pointer operand has already been converted
/// to a memref (`adaptor.getPtr()`).
///
/// Cases, in the order they are tried:
///  1. Loads left half-rewritten by AddPtrConverter are finished by
///     continueModifyFromAddPtrConverter.
///  2. Scalar results become a memref.load of element 0 of the scalar memref.
///     When both mask and other are present, the select is done on 1-element
///     tensors.
///  3. Tensor results:
///     * a local buffer is allocated. For ops carrying
///       ConverterUtils::discreteAttrName it is instead a subview of a
///       full-size buffer allocated in front of the outermost extracted loop;
///     * the buffer is populated by memref.copy according to the boundary
///       check, the absence of a mask, or a continuous mask;
///     * the result is exposed through toTensorAndReplace.
///
/// The buffer is tagged with MayImplicitTransposeWithLastAxisTAG when all of
/// these hold: `existDotFlag` is set, the op lacks GeneratedByMakeTensorPtrTAG,
/// the recovered last stride is not 1, and
/// ConverterUtils::isaPermutedMemRefType accepts the memref type.
LogicalResult
LoadConverter::matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  if (checkModifiedByAddPtrConverter(op).succeeded()) {
    return continueModifyFromAddPtrConverter(op, adaptor, rewriter);
  }
  auto ptr = adaptor.getPtr();
  auto mask = op.getMask();
  auto other = op.getOther();
  auto loc = op.getLoc();

  // Scalar load.
  if (!isa<ShapedType>(op.getResult().getType())) {
    auto scalarMemref =
        BlockDataParser::getScalarMemRef(op.getPtr(), ptr, loc, rewriter);
    auto resTy = op.getResult().getType();
    auto idxZero =
        rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
    auto loadedValue = rewriter.create<memref::LoadOp>(loc, resTy, scalarMemref,
                                                       idxZero.getResult()).getResult();
    propagateWasBoolToInt8Attr(op.getOperation(), loadedValue.getDefiningOp(), rewriter);
    if (mask && other) {
      // Apply the mask/other select on 1-element tensors, then extract the
      // scalar back out.
      mask = rewriter.create<triton::SplatOp>(loc, RankedTensorType::get({1}, mask.getType()), mask);
      loadedValue = rewriter.create<triton::SplatOp>(loc, RankedTensorType::get({1}, loadedValue.getType()), loadedValue);
      other = rewriter.create<triton::SplatOp>(loc, RankedTensorType::get({1}, other.getType()), other);
      loadedValue = rewriter.create<arith::SelectOp>(loc, mask, loadedValue, other);
      rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, loadedValue, ValueRange({idxZero}));
    } else {
      rewriter.replaceOp(op, loadedValue);
    }
    return success();
  }

  // Recover the last stride of the source memref when it comes from a
  // memref.reinterpret_cast. A loop-carried block-argument pointer is traced
  // back through the init operands of the enclosing scf.for ops.
  int64_t lastStride = -1;
  if (isa<BlockArgument>(ptr)) {
    Value u = ptr;
    while (auto blkArg = dyn_cast<BlockArgument>(u)) {
      auto forOp = dyn_cast<scf::ForOp>(blkArg.getOwner()->getParentOp());
      // Body argument 0 of an scf.for is the induction variable and has no
      // tied init operand. Iter_arg k (k >= 1) is tied to init arg k - 1.
      if (!forOp || blkArg.getArgNumber() == 0) {
        u = nullptr;
        break;
      }
      u = forOp.getInitArgs()[blkArg.getArgNumber() - 1];
    }
    if (u) {
      if (auto castOp = u.getDefiningOp<memref::ReinterpretCastOp>()) {
        auto ret = mlir::ConverterUtils::getLastStrideOfReinterpretCastOp(castOp);
        if (ret.has_value())
          lastStride = *ret;
      }
    }
  }
  auto memRefType = dyn_cast<MemRefType>(ptr.getType());
  if (!memRefType) {
    return rewriter.notifyMatchFailure(
        op, "LoadOp expects a memref, not a memref of pointers");
  }
  if (!op->hasAttr(ConverterUtils::GeneratedByMakeTensorPtrTAG)) {
    // Only query the stride when the pointer is produced directly by a
    // memref.reinterpret_cast. A block-argument pointer has no defining op;
    // its stride, if any, was recovered by the tracing above.
    if (auto memrefOp = ptr.getDefiningOp<memref::ReinterpretCastOp>()) {
      auto ret = mlir::ConverterUtils::getLastStrideOfReinterpretCastOp(memrefOp);
      if (ret.has_value())
        lastStride = *ret;
    }
  }
  bool mayImplicitTransposeWithLastAxis =
      (existDotFlag) && (!op->hasAttr(ConverterUtils::GeneratedByMakeTensorPtrTAG)) &&
      (lastStride != 1 && mlir::ConverterUtils::isaPermutedMemRefType(memRefType));
  auto memRefShape = memRefType.getShape();
  auto memRefElementType = memRefType.getElementType();
  // Buffer the load writes into. For discrete loads this is a subview of the
  // full-size buffer kept in `allocOpTmp`.
  Value allocOp;
  Value allocOpTmp;
  if (op->hasAttr(ConverterUtils::discreteAttrName)) {
    // Walk outward through the loops produced by load/store extraction.
    Operation *loop = op->getParentOp();
    int extractedLoopCount = 1;
    for (auto parentOp = loop->getParentOp();
         parentOp->hasAttr("ExtractedLoadOrStore");
         parentOp = parentOp->getParentOp()) {
      loop = parentOp;
      extractedLoopCount++;
    }
    // Validate the expected structure before any IR is modified.
    auto loopOp = dyn_cast<scf::ForOp>(loop);
    if (!loopOp || loopOp.getInitArgs().empty() ||
        !isa<RankedTensorType>(loopOp.getInitArgs()[0].getType()))
      return rewriter.notifyMatchFailure(
          op, "discrete load expects an scf.for carrying a ranked tensor");
    tensor::InsertSliceOp insertSliceOp = nullptr;
    for (auto *user : op->getUsers()) {
      if (auto targetOp = dyn_cast<tensor::InsertSliceOp>(user)) {
        insertSliceOp = targetOp;
        break;
      }
    }
    if (!insertSliceOp)
      return rewriter.notifyMatchFailure(
          op, "discrete load expects a tensor.insert_slice user");

    rewriter.setInsertionPoint(loop);
    auto fullMemRefShape =
        cast<RankedTensorType>(loopOp.getInitArgs()[0].getType()).getShape();
    auto fullMemRefType = MemRefType::get(fullMemRefShape, memRefElementType);
    bool isIndexSelectScenario = (extractedLoopCount == 1) && (fullMemRefShape.size() > 1u);
    if (isIndexSelectScenario)
      loopOp->setAttr("hivm.parallel_loop", rewriter.getUnitAttr());
    allocOp = rewriter.create<memref::AllocOp>(loc, fullMemRefType);
    allocOpTmp = allocOp;
    rewriter.setInsertionPointAfter(loop);
    auto toTensorOp = rewriter.create<bufferization::ToTensorOp>(
        loc, RankedTensorType::get(fullMemRefShape, memRefElementType), allocOp, true, true);
    rewriter.replaceAllUsesWith(loopOp->getResult(0), toTensorOp->getResult(0));
    auto offsets = insertSliceOp.getMixedOffsets();
    auto sizes = insertSliceOp.getMixedSizes();
    auto strides = insertSliceOp.getMixedStrides();
    auto allocType = memref::SubViewOp::inferResultType(fullMemRefType, offsets,
                                                        sizes, strides);
    rewriter.setInsertionPoint(op);
    allocOp = rewriter.create<memref::SubViewOp>(
        loc, cast<MemRefType>(allocType), allocOp, offsets, sizes, strides);
    rewriter.replaceAllUsesExcept(insertSliceOp.getResult(),
                                  insertSliceOp.getDest(), insertSliceOp);
    rewriter.eraseOp(insertSliceOp);
  } else {
    allocOp = rewriter.create<memref::AllocOp>(
        loc, MemRefType::get(memRefShape, memRefElementType));
  }

  // Tags the buffer backing the load with MayImplicitTransposeWithLastAxisTAG.
  // When `allocOp` is a subview, the full-size buffer is tagged instead.
  auto markBufferForImplicitTranspose = [&]() {
    if (!mayImplicitTransposeWithLastAxis)
      return;
    Value target;
    if (allocOp.getDefiningOp<memref::AllocOp>())
      target = allocOp;
    else if (allocOp.getDefiningOp<memref::SubViewOp>())
      target = allocOpTmp;
    else
      return;
    auto markOp = rewriter.create<annotation::MarkOp>(loc, target);
    markOp->setAttr(MayImplicitTransposeWithLastAxisTAG, UnitAttr::get(rewriter.getContext()));
  };

  auto tensorType = RankedTensorType::get(memRefShape, memRefElementType);
  auto boundaryCheck = op.getBoundaryCheck();
  if (!boundaryCheck.empty()) {
    auto makeTensorPtrOp = op.getPtr().getDefiningOp<triton::MakeTensorPtrOp>();
    auto boundarySizes = mlir::ConverterUtils::getBoundarySizes(
        boundaryCheck, ptr, loc, rewriter);
    auto padding = op.getPadding();
    SmallVector<OpFoldResult> srcOffsets(boundarySizes.size(), rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> dstOffsets;
    if (makeTensorPtrOp) {
      // For checked dimensions, shift the destination by max(-offset, 0),
      // capped at the boundary size, and shrink the copied extent to match.
      auto zeroVal = rewriter.createOrFold<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
      for (auto [idx, offVal] : llvm::enumerate(makeTensorPtrOp.getOffsets())) {
        if (llvm::find(boundaryCheck, idx) == boundaryCheck.end()) {
          dstOffsets.push_back(srcOffsets[idx]);
          continue;
        }
        Value offset = rewriter.createOrFold<arith::SubIOp>(loc, zeroVal, offVal);
        Value size = getValueOrCreateConstantIndexOp(rewriter, loc, boundarySizes[idx]);
        offset = rewriter.createOrFold<arith::MaxSIOp>(loc, offset, zeroVal);
        offset = rewriter.createOrFold<arith::IndexCastOp>(loc, rewriter.getIndexType(), offset);
        OpFoldResult ofr;
        if (auto constOp = offset.getDefiningOp<arith::ConstantOp>()) {
          ofr = constOp.getValue();
        } else {
          ofr = offset;
        }
        ofr = minOpFoldResult(ofr, size, loc, rewriter);
        boundarySizes[idx] = subOpFoldResult(size, ofr, loc, rewriter);
        dstOffsets.push_back(ofr);
      }
    } else {
      dstOffsets = srcOffsets;
    }
    if (padding.has_value()) {
      TypedAttr padAttr = rewriter.getZeroAttr(memRefElementType);
      if (padding.value() == triton::PaddingOption::PAD_NAN) {
        assert(!memRefElementType.isIntOrIndex());
        auto apNaN = llvm::APFloat::getNaN(
            cast<FloatAttr>(padAttr).getValue().getSemantics());
        padAttr = rewriter.getFloatAttr(memRefElementType, apNaN);
      }
      auto padVal = rewriter.create<arith::ConstantOp>(loc, padAttr);
      fillTensorWithOtherForMaskScenario(padVal, allocOp, boundarySizes,
                                         rewriter);
    }
    auto srcSubView =
        mlir::ConverterUtils::makeSubViewOp(ptr, srcOffsets, boundarySizes, loc, rewriter);
    auto dstSubview = mlir::ConverterUtils::makeSubViewOp(
        allocOp, dstOffsets, boundarySizes, loc, rewriter);
    auto copyOp = rewriter.create<memref::CopyOp>(loc, srcSubView, dstSubview);
    propagateWasBoolToInt8Attr(op.getOperation(), copyOp.getOperation(), rewriter);
    if (mayImplicitTransposeWithLastAxis) {
      auto markOp = rewriter.create<annotation::MarkOp>(loc, dstSubview);
      markOp->setAttr(MayImplicitTransposeWithLastAxisTAG, UnitAttr::get(rewriter.getContext()));
    }
    return this->toTensorAndReplace(op, tensorType, allocOp, mayImplicitTransposeWithLastAxis, loc, rewriter);
  }

  // Unmasked load: copy the whole source memref.
  if (!mask) {
    assert(!other && "can not input 'other' when 'mask' is not set");
    if (ptr.getDefiningOp<UnrealizedConversionCastOp>()) {
      op->emitError("meeting unexpected UCC in LoadConverter!");
      return failure();
    }
    auto [ptrStrides, ptrOffsets] = memRefType.getStridesAndOffset();
    if (ptrStrides.back() == 2 && (memRefShape.back() % 2 == 0) &&
        mlir::triton::DeinterleaveStatusOptimization(op, adaptor, rewriter)
            .succeeded()) {
      return success();
    }
    auto copyOp = rewriter.create<memref::CopyOp>(loc, ptr, allocOp);
    propagateWasBoolToInt8Attr(op.getOperation(), copyOp.getOperation(), rewriter);
    markBufferForImplicitTranspose();
    return this->toTensorAndReplace(op, tensorType, allocOp, mayImplicitTransposeWithLastAxis, loc, rewriter);
  }

  // Masked load: only continuous masks are supported.
  MaskState mstate;
  auto isContMask = mstate.parse(mask, loc, rewriter);
  if (isContMask.failed()) {
    return rewriter.notifyMatchFailure(
        op, "can not lower uncontinuout masked loads");
  }
  if (other) {
    auto scalarOther =
        mlir::ConverterUtils::getScalarValue(other, loc, rewriter);
    assert(
        scalarOther &&
        "other value used in masked load produced by unsupported instruction!");
    fillTensorWithOtherForMaskScenario(scalarOther, allocOp, mstate.dims,
                                       rewriter);
  }
  if (mstate.getRank() == memRefType.getRank() &&
      isConstantIntValue(mstate.offsets.back(), 0) &&
      isConstantIntValue(mstate.dims.back(), memRefType.getShape().back())) {
    auto [ptrStrides, ptrOffsets] = memRefType.getStridesAndOffset();
    if (ptrStrides.back() == 2 && (memRefType.getShape().back() % 2 == 0) &&
        DeinterleaveStatusWithMaskOptimization(op, adaptor, rewriter, mstate,
                                               allocOp)
            .succeeded()) {
      return success();
    }
  }
  if (ptr.getDefiningOp<UnrealizedConversionCastOp>()) {
    op->emitError("meeting unexpected UCC in LoadConverter!");
    return failure();
  }
  if (mstate.isMemrefSubviewValid(ptr, rewriter)) {
    memref::SubViewOp srcSubView = mstate.getSubview(ptr, loc, rewriter);
    memref::SubViewOp dstSubView = mstate.getSubview(allocOp, loc, rewriter);
    MemRefType dstSubViewType = mlir::cast<MemRefType>(dstSubView.getType());
    auto [srcStrides, srcOffset] = dstSubViewType.getStridesAndOffset();
    MemRefType castType = MemRefType::get(
        dstSubViewType.getShape(),
        dstSubViewType.getElementType(),
        makeStridedLinearLayoutMap(srcStrides, srcOffset, rewriter.getContext()));
    auto castOp = rewriter.create<memref::CastOp>(loc, castType, dstSubView);
    auto copyOp = rewriter.create<memref::CopyOp>(loc, srcSubView, castOp);
    propagateWasBoolToInt8Attr(op.getOperation(), copyOp.getOperation(), rewriter);
  }
  markBufferForImplicitTranspose();
  return this->toTensorAndReplace(op, tensorType, allocOp, mayImplicitTransposeWithLastAxis, loc, rewriter);
}
// Constructs the triton::AtomicRMWOp conversion pattern in `context`; no
// explicit pattern benefit is passed to the base class.
AtomicRMWConverter::AtomicRMWConverter(MLIRContext *context)
: OpConversionPattern<triton::AtomicRMWOp>(context) {}
/// Lowers a tensor triton::AtomicRMWOp whose pointer is a converted memref.
///
/// * ADD/FADD/MAX/MIN on f16/bf16/f32/i8/i16/i32 become hivm::StoreOp with an
///   atomic kind.
/// * XCHG becomes hfusion::AtomicXchgOp. When the mask cannot be parsed as a
///   continuous mask, the mask is passed along as a memref.
/// * Every other kind becomes hfusion::StoreOp when `compileOn91095Flag` is
///   set, and hfusion::AtomicRMWOp otherwise.
///
/// When the result is used, the old destination contents are first copied into
/// a fresh buffer, and the op is replaced by that buffer viewed as a tensor.
/// A constant mask for which isConstantMaskTrue fails erases the op.
LogicalResult
AtomicRMWConverter::matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  auto ptr = adaptor.getPtr();
  auto val = op.getVal();
  auto loc = op.getLoc();
  auto mask = op.getMask();
  auto rmwOp = op.getAtomicRmwOp();
  auto resType = dyn_cast<TensorType>(op.getResult().getType());
  auto ptrType = dyn_cast<MemRefType>(ptr.getType());
  if (!resType)
    return rewriter.notifyMatchFailure(
        op, "atomicRMWConverter: scalar will be handled by "
            "ScalarAtomicRMWCanonicalizer");
  if (!ptrType)
    return rewriter.notifyMatchFailure(
        op, "AtomicRMWOp expects a memref, not a memref of pointers");
  // Immutable lookup tables, built once instead of on every invocation.
  static const std::map<RMWOp, hivm::AtomicKind> atomicKindMap = {
      {RMWOp::ADD, hivm::AtomicKind::ADD},
      {RMWOp::FADD, hivm::AtomicKind::ADD},
      {RMWOp::OR, hivm::AtomicKind::OR},
      {RMWOp::XOR, hivm::AtomicKind::XOR},
      {RMWOp::AND, hivm::AtomicKind::AND},
      {RMWOp::MIN, hivm::AtomicKind::MIN},
      {RMWOp::UMIN, hivm::AtomicKind::UMIN},
      {RMWOp::MAX, hivm::AtomicKind::MAX},
      {RMWOp::UMAX, hivm::AtomicKind::UMAX},
      {RMWOp::XCHG, hivm::AtomicKind::XCHG},
  };
  static const std::map<RMWOp, hfusion::AtomicKind> hfusionAtomicKindMap = {
      {RMWOp::ADD, hfusion::AtomicKind::ADD},
      {RMWOp::FADD, hfusion::AtomicKind::ADD},
      {RMWOp::OR, hfusion::AtomicKind::OR},
      {RMWOp::XOR, hfusion::AtomicKind::XOR},
      {RMWOp::AND, hfusion::AtomicKind::AND},
      {RMWOp::MIN, hfusion::AtomicKind::MIN},
      {RMWOp::UMIN, hfusion::AtomicKind::UMIN},
      {RMWOp::MAX, hfusion::AtomicKind::MAX},
      {RMWOp::UMAX, hfusion::AtomicKind::UMAX},
      {RMWOp::XCHG, hfusion::AtomicKind::XCHG},
  };
  auto hivmKindIt = atomicKindMap.find(rmwOp);
  auto hfusionKindIt = hfusionAtomicKindMap.find(rmwOp);
  if (hivmKindIt == atomicKindMap.end() ||
      hfusionKindIt == hfusionAtomicKindMap.end())
    return rewriter.notifyMatchFailure(op, "unsupported atomic RMW kind");
  auto atomicKind =
      hivm::AtomicKindAttr::get(rewriter.getContext(), hivmKindIt->second);
  auto hfusionAtomicKind =
      hfusion::AtomicKindAttr::get(rewriter.getContext(), hfusionKindIt->second);
  // Queried once, before anything is replaced, so that the replace-vs-erase
  // decision at the end does not depend on how the rewriter tracks uses after
  // replaceOp.
  const bool resultHasUses = !op.getResult().use_empty();
  auto dstMemref = ptr;
  Value inputVal = val;
  auto getInputMemref = [&]() -> Value {
    if (isa<MemRefType>(inputVal.getType()))
      return inputVal;
    return rewriter.create<bufferization::ToBufferOp>(loc, ptrType, inputVal);
  };
  auto inputMemref = getInputMemref();
  auto inputMemrefType = cast<MemRefType>(inputMemref.getType());
  auto elementType = inputMemrefType.getElementType();
  auto isHardwareSupported =
      (rmwOp == RMWOp::ADD || rmwOp == RMWOp::FADD || rmwOp == RMWOp::MAX || rmwOp == RMWOp::MIN) &&
      (elementType.isF16() || elementType.isBF16() || elementType.isF32() || elementType.isInteger(8) ||
       elementType.isInteger(16) || elementType.isInteger(32));
  bool isDiscreteMask = false;
  if (mask) {
    auto constantMask = mask.getDefiningOp<arith::ConstantOp>();
    if (constantMask && !isConstantMaskTrue(mask)) {
      // NOTE(review): the op is erased even when its result still has uses.
      rewriter.eraseOp(op);
      return success();
    }
    MaskState mstate;
    isDiscreteMask = mstate.parse(mask, loc, rewriter).failed();
    if (!constantMask && !isDiscreteMask) {
      // Continuous mask: restrict destination and input to the masked region.
      dstMemref = mstate.getSubview(ptr, loc, rewriter);
      if (isHardwareSupported) {
        auto inputTensorType = RankedTensorType::get(inputMemrefType.getShape(), inputMemrefType.getElementType());
        if (!isa<RankedTensorType>(inputVal.getType()))
          inputVal = rewriter.create<bufferization::ToTensorOp>(loc, inputTensorType, inputMemref, true, true);
        inputVal = mstate.getExtractSlice(inputVal, loc, rewriter);
      } else {
        inputMemref = mstate.getSubview(inputMemref, loc, rewriter);
      }
    }
  }
  // Reject unsupported discrete masks before the op is replaced.
  if (isDiscreteMask && rmwOp != RMWOp::XCHG) {
    return op.emitError("Discrete mask is only expected for XCHG; other atomics "
                        "should be lowered without discrete masks");
  }
  if (resultHasUses) {
    // The atomic returns the old values: snapshot the destination first.
    auto tensorType =
        RankedTensorType::get(ptrType.getShape(), ptrType.getElementType());
    auto alloc = rewriter.create<memref::AllocOp>(
        loc, MemRefType::get(ptrType.getShape(), ptrType.getElementType()));
    rewriter.create<memref::CopyOp>(loc, ptr, alloc);
    Value tensorToReplace = rewriter.create<bufferization::ToTensorOp>(
        loc, tensorType, alloc, true, true);
    rewriter.replaceOp(op, tensorToReplace);
  }
  if (isDiscreteMask) {
    Value memrefMask = mask;
    if (auto maskTypeT = dyn_cast<TensorType>(mask.getType())) {
      MemRefType maskTypeM = MemRefType::get(maskTypeT.getShape(), maskTypeT.getElementType());
      memrefMask =
          rewriter.create<bufferization::ToBufferOp>(loc, maskTypeM, mask);
    }
    rewriter.create<hfusion::AtomicXchgOp>(op.getLoc(), TypeRange(), inputMemref, dstMemref, memrefMask);
  } else if (isHardwareSupported) {
    rewriter.create<hivm::StoreOp>(op.getLoc(), TypeRange{}, inputVal, dstMemref, atomicKind);
  } else if (rmwOp == RMWOp::XCHG) {
    rewriter.create<hfusion::AtomicXchgOp>(op.getLoc(), TypeRange(), inputMemref, dstMemref);
  } else {
    if (rmwOp == RMWOp::OR || rmwOp == RMWOp::XOR || rmwOp == RMWOp::AND) {
      if (!elementType.isSignlessIntOrIndex()) {
        return op->emitOpError() << "must be signless-integer-like, but got " << elementType;
      }
    }
    if (compileOn91095Flag) {
      rewriter.create<hfusion::StoreOp>(op.getLoc(), TypeRange{}, ValueRange{inputMemref}, ValueRange{dstMemref},
                                        hfusionAtomicKind, ArrayRef<NamedAttribute>{});
    } else {
      rewriter.create<hfusion::AtomicRMWOp>(op.getLoc(), TypeRange(), inputMemref, dstMemref, hfusionAtomicKind);
    }
  }
  if (!resultHasUses) {
    rewriter.eraseOp(op);
  }
  return success();
}
/// Lowers a tensor triton::AtomicCASOp to a software compare-and-swap.
///
/// The CAS is a linalg.generic over the destination memref. For each element it
/// yields `val` when the current value equals `cmp` (arith.cmpf ueq for
/// floats, arith.cmpi eq otherwise), and keeps the current value otherwise.
/// The generic is tagged GenericAtomicRMW="cas", MemSemantic, MemSyncScope and
/// Software.
///
/// When the result is used, the old destination contents are first copied
/// into a fresh buffer, which replaces the op as a tensor.
LogicalResult
AtomicCASConverter::matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  auto ptr = adaptor.getPtr();
  auto cmp = op.getCmp();
  auto val = op.getVal();
  auto loc = op.getLoc();
  auto resType = dyn_cast<TensorType>(op.getResult().getType());
  if (!resType) {
    return rewriter.notifyMatchFailure(
        op, "atomicCASConverter: scalar will be handled by "
            "ScalarAtomicCASCanonicalizer");
  }
  auto type = dyn_cast<MemRefType>(ptr.getType());
  if (!type) {
    return rewriter.notifyMatchFailure(
        op, "AtomicCASOp expects a memref, not a memref of pointers");
  }
  // Queried once, before anything is replaced, so the final erase decision
  // does not depend on how the rewriter tracks uses after replaceOp.
  const bool resultHasUses = !op.getResult().use_empty();
  auto dstMemref = ptr;
  // Identity-layout memref with the destination's shape and element type.
  MemRefType dstType = MemRefType::get(type.getShape(), type.getElementType());
  Value inputMemref =
      rewriter.create<bufferization::ToBufferOp>(loc, dstType, val);
  Value cmpMemref =
      rewriter.create<bufferization::ToBufferOp>(loc, dstType, cmp);
  int64_t rank = type.getRank();
  auto context = rewriter.getContext();
  // dst, cmp and val inputs plus the dst output all use the identity map.
  SmallVector<AffineMap> indexingMaps(
      4, AffineMap::getMultiDimIdentityMap(static_cast<unsigned>(rank), context));
  if (resultHasUses) {
    auto tensorType =
        RankedTensorType::get(type.getShape(), type.getElementType());
    auto alloc = rewriter.create<memref::AllocOp>(loc, dstType);
    rewriter.create<memref::CopyOp>(loc, ptr, alloc);
    Value tensor = rewriter.create<bufferization::ToTensorOp>(
        loc, tensorType, alloc, true, true);
    rewriter.replaceOp(op, tensor);
  }
  auto linalgOp = rewriter.create<linalg::GenericOp>(
      loc, ValueRange{dstMemref, cmpMemref, inputMemref},
      mlir::ValueRange{dstMemref}, indexingMaps,
      mlir::ConverterUtils::getNParallelLoopsAttrs(rank),
      [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
        Value lhs = blockArgs[0];
        Value rhs = blockArgs[1];
        Value setValue = blockArgs[2];
        Value cond;
        if (mlir::isa<mlir::FloatType>(lhs.getType())) {
          cond = nestedBuilder.create<arith::CmpFOp>(
              nestedLoc, arith::CmpFPredicate::UEQ, lhs, rhs);
        } else {
          cond = nestedBuilder.create<arith::CmpIOp>(
              nestedLoc, arith::CmpIPredicate::eq, lhs, rhs);
        }
        auto ifOp = nestedBuilder.create<scf::IfOp>(
            nestedLoc, TypeRange{setValue.getType()}, cond, true);
        {
          OpBuilder::InsertionGuard guard(nestedBuilder);
          nestedBuilder.setInsertionPointToEnd(&ifOp.getThenRegion().front());
          nestedBuilder.create<scf::YieldOp>(nestedLoc, setValue);
        }
        {
          OpBuilder::InsertionGuard guard(nestedBuilder);
          nestedBuilder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
          nestedBuilder.create<scf::YieldOp>(nestedLoc, lhs);
        }
        nestedBuilder.setInsertionPointToEnd(nestedBuilder.getBlock());
        nestedBuilder.create<mlir::linalg::YieldOp>(nestedLoc,
                                                    ifOp.getResult(0));
      });
  const StringRef genericAtomicRMW = "GenericAtomicRMW";
  const StringRef memSemantic = "MemSemantic";
  const StringRef memSyncScope = "MemSyncScope";
  auto attr = mlir::StringAttr::get(context, "cas");
  linalgOp->setAttr(genericAtomicRMW, attr);
  linalgOp->setAttr(memSemantic,
                    rewriter.getStringAttr(stringifyEnum(op.getSem())));
  linalgOp->setAttr(memSyncScope,
                    rewriter.getStringAttr(stringifyEnum(op.getScope())));
  linalgOp->setAttr("Software", rewriter.getUnitAttr());
  if (!resultHasUses) {
    rewriter.eraseOp(op);
  }
  return success();
}
/// Rewrites a triton::StoreOp of a scalar value.
///
/// * With a mask, the op becomes `scf.if %mask { tt.store %ptr, %value }`;
///   the inner store has no mask.
/// * Without a mask, the op becomes a tt.store of 1-element splats of the
///   pointer and the value.
///
/// Stores of non-scalar values are rejected so the tensor path handles them.
/// Cache and eviction policies are carried over in both cases.
LogicalResult
ScalarStoreCanonicalizer::matchAndRewrite(triton::StoreOp op,
                                          PatternRewriter &rewriter) const {
  Value value = op.getValue();
  if (!value.getType().isIntOrIndexOrFloat())
    return rewriter.notifyMatchFailure(
        op, "ScalarStoreCanonicalizer handles scalar store scene!");

  Value ptr = op.getPtr();
  if (Value mask = op.getMask()) {
    auto emitGuardedStore = [&](OpBuilder &b, Location bodyLoc) {
      b.create<triton::StoreOp>(bodyLoc, ptr, value, op.getCache(),
                                op.getEvict());
      b.create<scf::YieldOp>(bodyLoc);
    };
    rewriter.replaceOpWithNewOp<scf::IfOp>(op, mask, emitGuardedStore);
    return success();
  }

  const Location loc = op.getLoc();
  auto splatToOneElement = [&](Value scalar) -> Value {
    auto oneElemTy = RankedTensorType::get({int64_t{1}}, scalar.getType());
    return rewriter.create<triton::SplatOp>(loc, oneElemTy, scalar);
  };
  Value ptrSplat = splatToOneElement(ptr);
  Value valSplat = splatToOneElement(value);
  auto tensorStore = rewriter.create<triton::StoreOp>(
      loc, ptrSplat, valSplat, op.getCache(), op.getEvict());
  rewriter.replaceOp(op, tensorStore);
  return success();
}
/// Rewrites a scalar triton::AtomicRMWOp into its tensor form.
///
/// The pointer, the value and, when present, the mask are splatted to
/// 1-element tensors, a tensor atomic is built from them, and element 0 of
/// its result replaces the original op.
LogicalResult
ScalarAtomicRMWCanonicalizer::matchAndRewrite(triton::AtomicRMWOp op,
                                              PatternRewriter &rewriter) const {
  if (!op.getVal().getType().isIntOrIndexOrFloat()) {
    return rewriter.notifyMatchFailure(
        op, "ScalarAtomicRMWCanonicalizer handles scalar atomic rmw op scene!");
  }
  auto loc = op.getLoc();
  auto splatToOneElement = [&](Value scalar) -> Value {
    auto oneElemTy = RankedTensorType::get({int64_t{1}}, scalar.getType());
    return rewriter.create<triton::SplatOp>(loc, oneElemTy, scalar);
  };
  Value ptrSplat = splatToOneElement(op.getPtr());
  Value valSplat = splatToOneElement(op.getVal());
  // The mask operand is optional; splat it only when present. A null Value
  // is forwarded to the builder otherwise.
  Value maskSplat;
  if (Value mask = op.getMask())
    maskSplat = splatToOneElement(mask);
  auto newAtomicOp = rewriter.create<triton::AtomicRMWOp>(
      loc, valSplat.getType(), op.getAtomicRmwOp(), ptrSplat, valSplat,
      maskSplat, op.getSem(), op.getScope());
  auto idxZero =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, newAtomicOp, ValueRange({idxZero}));
  return success();
}
/// Rewrites a scalar triton::AtomicCASOp into its tensor form.
///
/// The pointer, `cmp` and `val` are splatted to 1-element tensors, a tensor
/// CAS is built from them, and element 0 of its result replaces the original
/// op.
LogicalResult
ScalarAtomicCASCanonicalizer::matchAndRewrite(triton::AtomicCASOp op,
                                              PatternRewriter &rewriter) const {
  // Both `cmp` and `val` must be scalars; bail out if either one is not.
  if (!op.getVal().getType().isIntOrIndexOrFloat() ||
      !op.getCmp().getType().isIntOrIndexOrFloat()) {
    return rewriter.notifyMatchFailure(
        op, "ScalarAtomicCASCanonicalizer handles scalar atomic cas op scene!");
  }
  auto loc = op.getLoc();
  auto splatToOneElement = [&](Value scalar) -> Value {
    auto oneElemTy = RankedTensorType::get({int64_t{1}}, scalar.getType());
    return rewriter.create<triton::SplatOp>(loc, oneElemTy, scalar);
  };
  Value ptrSplat = splatToOneElement(op.getPtr());
  Value cmpSplat = splatToOneElement(op.getCmp());
  Value valSplat = splatToOneElement(op.getVal());
  auto newAtomicOp = rewriter.create<triton::AtomicCASOp>(
      loc, valSplat.getType(), ptrSplat, cmpSplat, valSplat, op.getSem(),
      op.getScope());
  auto idxZero =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
  rewriter.replaceOpWithNewOp<tensor::ExtractOp>(op, newAtomicOp, ValueRange({idxZero}));
  return success();
}
/// Undoes the integer-bitcast form of a float atomic max/min.
///
/// Matches an atomic whose pointer and value both come from triton::BitcastOp
/// of float tensors.
/// * UMAX/UMIN atomics of that shape are erased when their result is unused.
/// * For MAX/MIN, a float atomic of the same kind is rebuilt on the original
///   pointer and value with a recovered mask. The mask is the LHS of an
///   arith.andi mask, or all-true when the mask is
///   `arith.cmpf oge <x>, +0.0`.
///
/// BitcastOp users reached through arith.select users of the old result are
/// rewired to the new float atomic.
LogicalResult
AtomicMaxMinCanonicalizer::matchAndRewrite(triton::AtomicRMWOp op,
                                           PatternRewriter &rewriter) const {
  auto ptrBitcastOp = op.getPtr().getDefiningOp<triton::BitcastOp>();
  auto valueBitcastOp = op.getVal().getDefiningOp<triton::BitcastOp>();
  if (!ptrBitcastOp || !valueBitcastOp) {
    return failure();
  }
  // Only tensor sources are handled; a scalar source would make the
  // TensorType cast null.
  auto srcTensorType = dyn_cast<TensorType>(valueBitcastOp.getSrc().getType());
  if (!srcTensorType || !isa<FloatType>(srcTensorType.getElementType())) {
    return failure();
  }
  auto rmwOp = op.getAtomicRmwOp();
  if (rmwOp == triton::RMWOp::UMAX || rmwOp == triton::RMWOp::UMIN) {
    if (op.getResult().use_empty()) {
      rewriter.eraseOp(op);
      return success();
    }
    return failure();
  }
  if (rmwOp != triton::RMWOp::MAX && rmwOp != triton::RMWOp::MIN) {
    return failure();
  }
  Value originalMask = op.getMask();
  // The mask operand is optional; without one there is nothing to recover.
  if (!originalMask)
    return failure();
  if (auto andOp = originalMask.getDefiningOp<arith::AndIOp>()) {
    originalMask = andOp.getLhs();
  } else if (auto cmpOp = originalMask.getDefiningOp<arith::CmpFOp>()) {
    if (cmpOp.getPredicate() != mlir::arith::CmpFPredicate::OGE ||
        !matchPattern(cmpOp.getRhs(), m_PosZeroFloat()))
      return op->emitError("Illegal mask for atomicrmwOp of float type");
    originalMask = rewriter.create<arith::ConstantOp>(
        op->getLoc(),
        DenseElementsAttr::get(
            cast<ShapedType>(originalMask.getType()), true));
  } else {
    return op->emitError("Illegal mask for atomicrmwOp of float type");
  }
  auto originAtomicOp = rewriter.create<triton::AtomicRMWOp>(
      op.getLoc(), valueBitcastOp.getSrc().getType(), op.getAtomicRmwOp(),
      ptrBitcastOp.getSrc(), valueBitcastOp.getSrc(), originalMask, op.getSem(),
      op.getScope());
  if (!op.getResult().use_empty()) {
    for (OpOperand &use : op->getUses()) {
      auto selectOp = dyn_cast<arith::SelectOp>(use.getOwner());
      if (!selectOp)
        continue;
      for (OpOperand &selectUse : selectOp->getUses()) {
        if (auto bitcastOp =
                dyn_cast<triton::BitcastOp>(selectUse.getOwner())) {
          // Go through the rewriter so the pattern driver is notified of the
          // modified users.
          rewriter.replaceAllUsesWith(bitcastOp.getResult(),
                                      originAtomicOp.getResult());
        }
      }
    }
    rewriter.replaceOp(op, originAtomicOp);
  } else {
    rewriter.eraseOp(op);
  }
  return success();
}
// Builds the triton.store conversion pattern; the context is simply forwarded
// to the OpConversionPattern base, which matches triton::StoreOp.
StoreConverter::StoreConverter(MLIRContext *context)
    : OpConversionPattern<triton::StoreOp>(context) {
  // No additional state to initialize.
}
// Lowers triton.store to bufferization.materialize_in_destination.
// Three shapes are handled:
//   * boundary-checked stores: the value and the destination are both sliced
//     to the in-bounds extent before the copy;
//   * unmasked stores: the whole value is copied into the destination;
//   * masked stores: MaskState derives the slice/subview from the mask, and
//     the pattern fails if the mask cannot be parsed.
LogicalResult
StoreConverter::matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                                ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  Value dstPtr = adaptor.getPtr();
  Value storedVal = adaptor.getValue();

  // Every successful path ends identically: a writable copy of `src` into
  // `dst`, followed by erasure of the original store.
  auto emitStore = [&](Value src, Value dst) -> LogicalResult {
    auto materializeOp =
        rewriter.create<bufferization::MaterializeInDestinationOp>(loc, src,
                                                                   dst);
    materializeOp.setWritable(true);
    rewriter.eraseOp(op);
    return success();
  };

  auto checkedDims = op.getBoundaryCheck();
  if (!checkedDims.empty()) {
    auto tensorPtrOp = op.getPtr().getDefiningOp<triton::MakeTensorPtrOp>();
    auto boundarySizes = mlir::ConverterUtils::getBoundarySizes(
        checkedDims, dstPtr, loc, rewriter);
    SmallVector<OpFoldResult> dstOffsets(boundarySizes.size(),
                                         rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> srcOffsets;
    if (!tensorPtrOp) {
      srcOffsets = dstOffsets;
    } else {
      Value zeroI32 = rewriter.createOrFold<arith::ConstantOp>(
          loc, rewriter.getI32IntegerAttr(0));
      for (auto [dim, dimOffset] : llvm::enumerate(tensorPtrOp.getOffsets())) {
        if (!llvm::is_contained(checkedDims, dim)) {
          // Unchecked dimension: the source slice starts at zero.
          srcOffsets.push_back(dstOffsets[dim]);
          continue;
        }
        // Number of leading elements to skip: max(-offset, 0), capped at the
        // dimension size; the slice extent shrinks by the same amount.
        Value negOffset =
            rewriter.createOrFold<arith::SubIOp>(loc, zeroI32, dimOffset);
        Value dimSize = getValueOrCreateConstantIndexOp(rewriter, loc,
                                                        boundarySizes[dim]);
        Value clamped =
            rewriter.createOrFold<arith::MaxSIOp>(loc, negOffset, zeroI32);
        Value clampedIdx = rewriter.createOrFold<arith::IndexCastOp>(
            loc, rewriter.getIndexType(), clamped);
        // Prefer a static attribute when the skip folded to a constant.
        OpFoldResult skip = clampedIdx;
        if (auto cst = clampedIdx.getDefiningOp<arith::ConstantOp>())
          skip = cst.getValue();
        skip = minOpFoldResult(skip, dimSize, loc, rewriter);
        boundarySizes[dim] = subOpFoldResult(dimSize, skip, loc, rewriter);
        srcOffsets.push_back(skip);
      }
    }
    auto srcSlice = mlir::ConverterUtils::makeExtractSliceOp(
        storedVal, srcOffsets, boundarySizes, loc, rewriter);
    auto dstSubview = mlir::ConverterUtils::makeSubViewOp(
        dstPtr, dstOffsets, boundarySizes, loc, rewriter);
    return emitStore(srcSlice, dstSubview);
  }

  auto mask = op.getMask();
  if (!mask)
    return emitStore(storedVal, dstPtr);

  MaskState mstate;
  if (mstate.parse(mask, loc, rewriter).failed())
    return failure();
  LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(op) << "\n"; });
  auto srcSlice = mstate.getExtractSlice(storedVal, loc, rewriter);
  auto dstSubview = mstate.getSubview(dstPtr, loc, rewriter);
  return emitStore(srcSlice, dstSubview);
}
// Returns true when `op` has at least one dynamic stride operand that is a
// constant zero on a dimension whose static size is 1. Returns false when the
// static size/stride lists disagree in length, when any size is dynamic, or
// when the dynamic stride operands run out before the dynamic stride markers.
bool ReinterpretCastStrideCanonicalizer::hasFixableZeroStride(memref::ReinterpretCastOp op)
{
    auto sizes = op.getStaticSizes();
    auto strides = op.getStaticStrides();
    auto strideOperands = op.getStrides();
    // Dimensions must line up one-to-one, and all sizes must be static so a
    // unit dimension can be recognized.
    if (sizes.size() != strides.size() || llvm::any_of(sizes, ShapedType::isDynamic))
        return false;
    // Dynamic stride markers consume the stride operands in order.
    auto nextOperand = strideOperands.begin();
    for (auto [dimSize, dimStride] : llvm::zip(sizes, strides)) {
        if (!ShapedType::isDynamic(dimStride))
            continue;
        if (nextOperand == strideOperands.end())
            return false;
        Value strideVal = *nextOperand;
        ++nextOperand;
        if (dimSize == 1 && mlir::isZero(OpFoldResult(strideVal)))
            return true;
    }
    return false;
}
// Rewrites a memref.reinterpret_cast so that every dynamic stride operand that
// is a constant zero on a static size-1 dimension becomes a constant 1. The
// index along a size-1 dimension is always 0, so that stride never affects an
// address; the result type is reused unchanged.
LogicalResult ReinterpretCastStrideCanonicalizer::matchAndRewrite(
    memref::ReinterpretCastOp op, PatternRewriter &rewriter) const
{
    if (!hasFixableZeroStride(op))
        return failure();
    auto staticSizes = op.getStaticSizes();
    auto staticStrides = op.getStaticStrides();
    auto dynamicStrides = op.getStrides();

    // Phase 1: inspect only. A pattern must not create IR and then return
    // failure, so collect the operand positions to patch before touching IR.
    SmallVector<unsigned> zeroStridePositions;
    unsigned dynStrideIdx = 0;
    for (unsigned i = 0, e = staticStrides.size(); i < e; ++i) {
        if (!ShapedType::isDynamic(staticStrides[i]))
            continue;
        if (dynStrideIdx >= dynamicStrides.size())
            return failure();
        Value oldStride = dynamicStrides[dynStrideIdx];
        if (staticSizes[i] == 1 && mlir::isZero(OpFoldResult(oldStride)))
            zeroStridePositions.push_back(dynStrideIdx);
        ++dynStrideIdx;
    }
    // Every dynamic stride operand must have been matched to a marker.
    if (dynStrideIdx != dynamicStrides.size())
        return failure();
    if (zeroStridePositions.empty())
        return failure();

    // Phase 2: all checks passed; it is now safe to create new IR.
    Value c1 = rewriter.create<arith::ConstantOp>(op.getLoc(), rewriter.getIndexAttr(1));
    SmallVector<Value> newDynamicStrides(dynamicStrides.begin(), dynamicStrides.end());
    for (unsigned pos : zeroStridePositions)
        newDynamicStrides[pos] = c1;

    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
        op,
        cast<MemRefType>(op.getResult().getType()),
        op.getSource(),
        op.getOffsets(),
        op.getSizes(),
        newDynamicStrides,
        op.getStaticOffsets(),
        op.getStaticSizes(),
        op.getStaticStrides());
    return success();
}
}