* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* Copyright (c) Microsoft Corporation.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "TritonToStructured/MaskAnalysis.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <optional>
#include <queue>
#include <string>
#include <utility>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/LogicalResult.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "TritonToStructured/PtrAnalysis.h"
#include "Utils/Utils.h"
#define DEBUG_TYPE "triton-to-structured-mask-analysis"
namespace TritonToStructured {
using namespace mlir;
using namespace triton;
bool dimInfo::setType(arith::CmpIPredicate Type) {
LLVM_DEBUG({
llvm::dbgs() << "----------------------------------------------\n";
llvm::dbgs() << "Setting compare type for dimIndex " << dimIndex << "\n";
llvm::dbgs() << "Type: " << Type << "\n";
llvm::dbgs() << "----------------------------------------------\n";
});
switch (Type) {
case arith::CmpIPredicate::slt:
this->currentType = dimInfo::CompareType::slt;
break;
case arith::CmpIPredicate::ult:
this->currentType = dimInfo::CompareType::ult;
break;
case arith::CmpIPredicate::sge:
this->currentType = dimInfo::CompareType::sge;
break;
case arith::CmpIPredicate::uge:
this->currentType = dimInfo::CompareType::uge;
break;
default:
return false;
}
return true;
}
bool dimInfo::compareTypeIsLess() const {
return this->currentType == dimInfo::CompareType::slt ||
this->currentType == dimInfo::CompareType::ult;
}
void dimInfo::dump() const {
llvm::dbgs() << "----------------------------------------------\n";
llvm::dbgs() << "MaskDimInfo: \n";
llvm::dbgs() << "offset = " << offset << "\n";
llvm::dbgs() << "shape = " << shape << "\n";
llvm::dbgs() << "rhs = " << rhs << "\n";
llvm::dbgs() << "isLessMode = " << compareTypeIsLess() << "\n";
llvm::dbgs() << "hasBroadCast = " << hasBroadCast << "\n";
llvm::dbgs() << "----------------------------------------------\n";
};
void MaskState::dump() const {
llvm::dbgs() << "----------------------------------------------\n";
llvm::dbgs() << "MaskState :\n";
llvm::dbgs() << "scalar = " << scalar << "\n";
llvm::dbgs() << "stateInfo.size = " << stateInfo.size() << "\n";
for (auto info : stateInfo)
info.dump();
llvm::dbgs() << "----------------------------------------------\n";
};
LogicalResult MaskState::parse(Value operand, const Location loc,
OpBuilder &builder) {
if (isa<IntegerType>(operand.getType()) && operand.getType().getIntOrFloatBitWidth() != 1) {
return this->parseIntScalar(operand, loc, builder);
} else if (auto op = operand.getDefiningOp<arith::ConstantOp>()) {
return this->parseConstant(op, loc, builder);
} else if (auto op = operand.getDefiningOp<arith::AddIOp>()) {
return this->parseAdd(op, loc, builder);
} else if (auto op = operand.getDefiningOp<arith::AndIOp>()) {
return this->parseAnd(op, loc, builder);
} else if (auto op = operand.getDefiningOp<arith::CmpIOp>()) {
return this->parseCmp(op, loc, builder);
} else if (auto op = operand.getDefiningOp<triton::MakeRangeOp>()) {
return this->parseMakeRange(op, loc, builder);
} else if (auto op = operand.getDefiningOp<triton::BroadcastOp>()) {
return this->parseBroadcast(op, loc, builder);
} else if (auto op = operand.getDefiningOp<triton::SplatOp>()) {
return this->parseSplat(op, loc, builder);
} else if (auto op = operand.getDefiningOp<triton::ExpandDimsOp>()) {
return this->parseExpandDims(op, loc, builder);
} else if (auto op = operand.getDefiningOp<arith::ExtSIOp>()) {
return this->parseExtSI(op, loc, builder);
} else if (auto op = operand.getDefiningOp<arith::RemSIOp>()) {
return this->parseRem(op, loc, builder);
} else if (auto op = operand.getDefiningOp<arith::DivSIOp>()) {
return this->parseDiv(op, loc, builder);
}
LLVM_DEBUG({
InFlightDiagnostic diag =
emitWarning(loc)
<< "MaskAnalysis: compare operand produced by an "
"unsupported operation\n";
});
return failure();
}
LogicalResult MaskState::parseConstant(arith::ConstantOp constOp,
const Location loc,
OpBuilder &builder) {
if (!this->isEmpty()) {
LLVM_DEBUG({
constOp.emitError(
"MaskAnalysis: MaskState should be empty when visiting constant");
});
return failure();
}
if (auto intType = dyn_cast<IntegerType>(constOp.getType())) {
if (intType.getWidth() == 1) {
LLVM_DEBUG({
constOp.emitWarning("MaskAnalysis: Unsupported constant for int1");
});
return failure();
}
}
if (isa<DenseElementsAttr>(constOp.getValue())) {
auto attr = cast<DenseElementsAttr>(constOp.getValue());
auto elementType = attr.getElementType();
if (!attr.isSplat() || !isa<IntegerType>(elementType)) {
LLVM_DEBUG({
constOp.emitError(
"MaskAnalysis: only support splat integer constant");
});
return failure();
}
auto values = attr.getValues<IntegerAttr>();
auto value = values[0].getValue();
auto constAttr = builder.getIndexAttr(value.getSExtValue());
auto op = arith::ConstantOp::materialize(builder, constAttr,
builder.getIndexType(), loc);
this->scalar = op.getValue();
} else {
auto value = cast<IntegerAttr>(constOp.getValue()).getInt();
this->scalar = builder.getIndexAttr(value);
}
return success();
}
LogicalResult MaskState::parseIntScalar(Value scalar, const Location loc,
OpBuilder &builder) {
if (!this->isEmpty()) {
LLVM_DEBUG({
InFlightDiagnostic diag = emitError(loc)
<< "MaskAnalysis: MaskState should be empty when "
"visiting integer scalar";
});
return failure();
}
auto castOp =
builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), scalar);
this->scalar = castOp.getResult();
return success();
}
LogicalResult MaskState::parseMakeRange(triton::MakeRangeOp rangeOp,
const Location loc,
OpBuilder &builder) {
if (!this->isEmpty()) {
LLVM_DEBUG({
rangeOp.emitError("MaskAnalysis: MaskState should be empty when visiting make_range");
});
return failure();
}
auto shape = cast<ShapedType>(rangeOp.getType()).getShape();
auto start = rangeOp.getStart();
auto end = rangeOp.getEnd();
auto stride = (end - start + shape[0] - 1) / shape[0];
if (stride != 1) {
LLVM_DEBUG({
InFlightDiagnostic diag =
emitError(loc)
<< "stride must be 1 for make_range whose result is used "
"as load or store masks";
});
return failure();
}
stateInfo.emplace_back(builder.getIndexAttr(start),
builder.getIndexAttr(shape[0]));
return success();
}
LogicalResult MaskState::parseExtSI(arith::ExtSIOp op, const Location loc,
OpBuilder &builder) {
if (!this->isEmpty()) {
LLVM_DEBUG({
op->emitError("MaskAnalysis: MaskState should be empty when visiting extsi");
});
return failure();
}
return parse(op.getIn(), loc, builder);
}
LogicalResult MaskState::parseSplat(triton::SplatOp splatOp, const Location loc,
OpBuilder &builder) {
if (!this->isEmpty()) {
LLVM_DEBUG({
splatOp.emitError("MaskAnalysis: MaskState should be empty when visiting splat");
});
return failure();
}
auto src = splatOp.getSrc();
auto dst = splatOp.getResult();
auto dstShape = cast<ShapedType>(dst.getType()).getShape();
if (!isa<IntegerType>(src.getType())) {
LLVM_DEBUG({
splatOp.emitError()
<< "splat source must be an integer scalar for load/store masks";
});
return failure();
}
if (failed(this->parse(src, loc, builder)))
return failure();
if (stateInfo.size() > 1 || (stateInfo.size() == 1 && !isOne(stateInfo.back().shape))) {
LLVM_DEBUG({
splatOp.emitError()
<< "splat from a non-scalar source is not supported, unless it's state size and shape are 1";
});
return failure();
}
auto zeroAttr = builder.getIndexAttr(0);
SmallVector<dimInfo> newStateInfo;
for (auto [i, shape] : llvm::enumerate(dstShape)) {
auto shapeAttr = builder.getIndexAttr(shape);
newStateInfo.emplace_back(zeroAttr, shapeAttr, i, true);
if (!stateInfo.empty() && shape == 1) {
newStateInfo.back() = stateInfo.back();
newStateInfo.back().dimIndex = i;
}
}
this->stateInfo = newStateInfo;
return success();
}
LogicalResult MaskState::parseExpandDims(triton::ExpandDimsOp expandDimsOp,
const Location loc,
OpBuilder &builder) {
if (!this->isEmpty()) {
LLVM_DEBUG({
expandDimsOp.emitError(
"MaskAnalysis: MaskState should be empty when visiting expand_dims");
});
return failure();
}
auto zeroAttr = builder.getIndexAttr(0);
auto defaultShape = builder.getIndexAttr(1);
if (failed(this->parse(expandDimsOp.getSrc(), loc, builder)))
return failure();
auto dstShape =
cast<ShapedType>(expandDimsOp.getResult().getType()).getShape();
auto axis = expandDimsOp.getAxis();
if (dstShape[axis] != 1) {
LLVM_DEBUG({
expandDimsOp.emitError(
"MaskAnalysis: unexpected dimension size in expand_dims");
});
return failure();
}
size_t insertPos = 0;
for (auto &info : stateInfo) {
if (info.dimIndex >= axis)
++info.dimIndex;
if (info.dimIndex < axis)
++insertPos;
}
dimInfo insertInfo(zeroAttr, defaultShape, axis, true);
stateInfo.insert(stateInfo.begin() + insertPos, insertInfo);
return success();
}
LogicalResult MaskState::parseAdd(arith::AddIOp addOp, const Location loc,
OpBuilder &builder) {
if (!this->isEmpty()) {
LLVM_DEBUG({
addOp.emitError("MaskAnalysis: MaskState should be empty when visiting add");
});
return failure();
}
MaskState lhsState;
if (failed(lhsState.parse(addOp.getLhs(), loc, builder)))
return failure();
MaskState rhsState;
if (failed(rhsState.parse(addOp.getRhs(), loc, builder)))
return failure();
return this->addStates(lhsState, rhsState, loc, builder);
}
LogicalResult MaskState::addStates(const MaskState &lhsState,
const MaskState &rhsState, Location loc,
OpBuilder &builder) {
if (lhsState.scalar && rhsState.scalar) {
LLVM_DEBUG({
InFlightDiagnostic diag =
emitWarning(loc) << "Unexpected case where both lhs and rhs are scalars";
});
return failure();
}
if (!lhsState.scalar && !rhsState.scalar) {
LLVM_DEBUG({
InFlightDiagnostic diag =
emitWarning(loc)
<< "Unsupported scenario where neither lhs nor rhs is a scalar";
});
return failure();
}
if (lhsState.scalar)
return addStateScalar(rhsState, lhsState.scalar, loc, builder);
else
return addStateScalar(lhsState, rhsState.scalar, loc, builder);
}
LogicalResult MaskState::addStateScalar(const MaskState &state,
const OpFoldResult scalar, Location loc,
OpBuilder &builder) {
for (auto info : state.stateInfo) {
info.offset = addOpFoldResult(info.offset, scalar, loc, builder);
this->stateInfo.emplace_back(info);
}
return success();
}
LogicalResult MaskState::parseBroadcast(triton::BroadcastOp broadcastOp,
const Location loc,
OpBuilder &builder) {
if (!this->isEmpty()) {
LLVM_DEBUG({
broadcastOp.emitError(
"MaskAnalysis: MaskState should be empty when visiting broadcast");
});
return failure();
}
LLVM_DEBUG({
llvm::dbgs() << "----------------------------------------------\n";
llvm::dbgs() << "Parsing BROADCAST operation: " << broadcastOp << "\n";
});
auto src = broadcastOp.getSrc();
auto dst = broadcastOp.getResult();
if (!isa<ShapedType>(dst.getType())) {
LLVM_DEBUG({
broadcastOp.emitError(
"MaskAnalysis: broadcast dst should be a shaped type");
});
return failure();
}
auto srcShape = cast<ShapedType>(src.getType()).getShape();
auto dstShape = cast<ShapedType>(dst.getType()).getShape();
if (srcShape.size() != dstShape.size()) {
LLVM_DEBUG({
broadcastOp.emitError(
"MaskAnalysis: broadcast src and dst should have the same rank");
});
return failure();
}
if (failed(parse(src, loc, builder)))
return failure();
LLVM_DEBUG({
llvm::dbgs() << "Before BROADCAST MaskState: \n";
this->dump();
});
for (size_t i = 0; i < srcShape.size(); ++i) {
if (srcShape[i] == dstShape[i])
continue;
else if (srcShape[i] < dstShape[i] && srcShape[i] == 1) {
for (auto &info : stateInfo) {
if (info.dimIndex != i)
continue;
info.shape = builder.getIndexAttr(dstShape[i]);
info.hasBroadCast = true;
}
} else {
LLVM_DEBUG({
broadcastOp.emitError(
"MaskAnalysis: unexpected dimensions used in broadcast");
});
return failure();
}
}
LLVM_DEBUG({
llvm::dbgs() << "After BROADCAST MaskState: \n";
this->dump();
llvm::dbgs() << "----------------------------------------------\n";
});
return success();
}
LogicalResult MaskState::parseCmp(arith::CmpIOp cmpOp, const Location loc,
OpBuilder &builder) {
if (!this->isEmpty()) {
LLVM_DEBUG({
cmpOp.emitError("MaskAnalysis: MaskState should be empty when visiting cmpi");
});
return failure();
}
LLVM_DEBUG({
llvm::dbgs() << "----------------------------------------------\n";
llvm::dbgs() << "Parsing CMP operation: " << cmpOp << "\n";
});
if (isa<IntegerType>(cmpOp.getLhs().getType()) &&
(cmpOp.getLhs().getDefiningOp<arith::CmpIOp>() ||
cmpOp.getRhs().getDefiningOp<arith::CmpIOp>())) {
LLVM_DEBUG({
cmpOp.emitWarning("MaskAnalysis: Unsupported nested cmpi scenario for int1");
});
return failure();
}
MaskState lhsState;
if (failed(lhsState.parse(cmpOp.getLhs(), loc, builder)))
return failure();
MaskState rhsState;
if (failed(rhsState.parse(cmpOp.getRhs(), loc, builder)))
return failure();
if (lhsState.scalar) {
if (lhsState.stateInfo.empty()) {
lhsState.stateInfo.emplace_back(
builder.getIndexAttr(0), builder.getIndexAttr(1));
}
for (auto &info : lhsState.stateInfo) {
if (isOne(info.shape)) {
info.offset = lhsState.scalar;
}
}
}
if (!rhsState.scalar) {
LLVM_DEBUG({
cmpOp.emitWarning("MaskAnalysis: Unsupported cmpi scenario");
});
return failure();
}
LLVM_DEBUG({
llvm::dbgs() << "LHS MaskState: \n";
lhsState.dump();
llvm::dbgs() << "RHS MaskState: \n";
rhsState.dump();
llvm::dbgs() << "----------------------------------------------\n";
});
auto cmpType = cmpOp.getPredicate();
for (auto &info : lhsState.stateInfo) {
if (info.hasBroadCast)
continue;
if (!info.setType(cmpType)) {
LLVM_DEBUG({
cmpOp.emitWarning("MaskAnalysis: Unsupported cmpi type");
});
return failure();
}
info.rhs = rhsState.scalar;
}
this->stateInfo = lhsState.stateInfo;
LLVM_DEBUG({
llvm::dbgs() << "After CMP MaskState: \n";
this->dump();
llvm::dbgs() << "----------------------------------------------\n";
});
return success();
}
LogicalResult MaskState::parseRem(arith::RemSIOp remOp, const Location loc,
OpBuilder &builder) {
if (!this->isEmpty()) {
LLVM_DEBUG({
remOp.emitError("MaskAnalysis: MaskState should be empty when visiting REMSI");
});
return failure();
}
MaskState lhsState;
if (failed(lhsState.parse(remOp.getLhs(), loc, builder)))
return failure();
MaskState rhsState;
if (failed(rhsState.parse(remOp.getRhs(), loc, builder)))
return failure();
if (lhsState.scalar || !rhsState.scalar) {
LLVM_DEBUG({
remOp.emitRemark("MaskAnalysis: Unsupported REMSI scenario");
});
return failure();
}
auto divisorAttr = rhsState.scalar;
if (!getIntAttr(divisorAttr).has_value()) {
LLVM_DEBUG({
remOp.emitError("MaskAnalysis: do not support dynamic divisor in REMSI.");
});
return failure();
}
SmallVector<dimInfo> newStateInfo;
auto zeroAttr = builder.getIndexAttr(0);
for (auto info : lhsState.stateInfo) {
if (info.hasBroadCast) {
newStateInfo.emplace_back(info);
continue;
}
if (!isMultiple(divisorAttr, info.shape) && !isMultiple(info.shape, divisorAttr)) {
LLVM_DEBUG({
remOp.emitError("MaskAnalysis: do not support dynamic stride before REMSI.");
});
return failure();
}
auto contiguousSize = minOpFoldResult(divisorAttr, info.shape, loc, builder);
auto nonContiguousSize = divOpFoldResult(info.shape, contiguousSize, loc, builder);
auto staticNonContiguousSize = getIntAttr(nonContiguousSize);
if (!staticNonContiguousSize.has_value()) {
LLVM_DEBUG({
remOp.emitError("MaskAnalysis: do not support dynamic size before REMSI.");
});
return failure();
}
if (staticNonContiguousSize.value() != 0)
newStateInfo.emplace_back(zeroAttr, nonContiguousSize, info.dimIndex, true);
auto newOffset = remOpFoldResult(info.offset, divisorAttr, loc, builder);
newStateInfo.emplace_back(newOffset, nonContiguousSize, info.dimIndex);
}
this->stateInfo = newStateInfo;
return success();
}
LogicalResult MaskState::parseDiv(arith::DivSIOp divOp, const Location loc,
OpBuilder &builder) {
if (!this->isEmpty()) {
LLVM_DEBUG({
divOp.emitError("MaskAnalysis: MaskState should be empty when visiting DIVSI");
});
return failure();
}
LLVM_DEBUG({
llvm::dbgs() << "----------------------------------------------\n";
llvm::dbgs() << "Parsing DIV operation: " << divOp << "\n";
});
MaskState lhsState;
if (failed(lhsState.parse(divOp.getLhs(), loc, builder)))
return failure();
MaskState rhsState;
if (failed(rhsState.parse(divOp.getRhs(), loc, builder)))
return failure();
LLVM_DEBUG({
llvm::dbgs() << "LHS MaskState: \n";
lhsState.dump();
llvm::dbgs() << "RHS MaskState: \n";
rhsState.dump();
llvm::dbgs() << "----------------------------------------------\n";
});
if (lhsState.scalar || !rhsState.scalar) {
LLVM_DEBUG({
divOp.emitRemark("MaskAnalysis: Unsupported DIVSI scenario");
});
return failure();
}
auto divisorAttr = rhsState.scalar;
if (!getIntAttr(divisorAttr).has_value()) {
LLVM_DEBUG({
divOp.emitError("MaskAnalysis: do not support dynamix divisor in DIVSI.");
});
return failure();
}
SmallVector<dimInfo> newStateInfo;
auto zeroAttr = builder.getIndexAttr(0);
for (auto info : lhsState.stateInfo) {
if (info.hasBroadCast) {
newStateInfo.emplace_back(info);
continue;
}
if (!isMultiple(divisorAttr, info.shape) && !isMultiple(info.shape, divisorAttr)) {
LLVM_DEBUG({
divOp.emitError("MaskAnalysis: do not support dynamix stride before DIVSI.");
});
return failure();
}
auto nonContiguousSize = minOpFoldResult(divisorAttr, info.shape, loc, builder);
auto contiguousSize = divOpFoldResult(info.shape, nonContiguousSize, loc, builder);
auto staticContiguousSize = getIntAttr(contiguousSize);
if (!staticContiguousSize.has_value()) {
LLVM_DEBUG({
divOp.emitError("MaskAnalysis: do not support dynamix size before DIVSI.");
});
return failure();
}
if (staticContiguousSize.value() != 0) {
auto newOffset = divOpFoldResult(info.offset, divisorAttr, loc, builder);
newStateInfo.emplace_back(newOffset, contiguousSize, info.dimIndex);
}
newStateInfo.emplace_back(zeroAttr, nonContiguousSize, info.dimIndex, true);
}
this->stateInfo = newStateInfo;
LLVM_DEBUG({
llvm::dbgs() << "After DIV MaskState: \n";
this->dump();
llvm::dbgs() << "----------------------------------------------\n";
});
return success();
}
LogicalResult MaskState::parseAnd(arith::AndIOp andOp, const Location loc,
OpBuilder &builder) {
if (!this->isEmpty()) {
LLVM_DEBUG({
andOp.emitError("MaskAnalysis: MaskState should be empty when visiting and");
});
return failure();
}
auto zeroAttr = builder.getIndexAttr(0);
LLVM_DEBUG({
llvm::dbgs() << "----------------------------------------------\n";
llvm::dbgs() << "Parsing AND operation: " << andOp << "\n";
});
if (isa<IntegerType>(andOp.getLhs().getType())) {
LLVM_DEBUG({
andOp.emitWarning("MaskAnalysis: Unsupported andi scenario for int1");
});
return failure();
}
MaskState lhsState;
if (failed(lhsState.parse(andOp.getLhs(), loc, builder)))
return failure();
MaskState rhsState;
if (failed(rhsState.parse(andOp.getRhs(), loc, builder)))
return failure();
LLVM_DEBUG({
llvm::dbgs() << "LHS MaskState: \n";
lhsState.dump();
llvm::dbgs() << "RHS MaskState: \n";
rhsState.dump();
llvm::dbgs() << "----------------------------------------------\n";
});
SmallVector<dimInfo> newStateInfo;
auto lIt = lhsState.stateInfo.begin();
auto rIt = rhsState.stateInfo.begin();
while (lIt != lhsState.stateInfo.end() && rIt != rhsState.stateInfo.end()) {
if (lIt->dimIndex != rIt->dimIndex) {
auto newInfo = lIt->dimIndex < rIt->dimIndex ? *lIt++ : *rIt++;
newStateInfo.emplace_back(newInfo);
continue;
}
if (!isMultiple(lIt->shape, rIt->shape) &&
!isMultiple(rIt->shape, lIt->shape)) {
LLVM_DEBUG({
llvm::dbgs() << "LHS MaskState: \n";
lhsState.dump();
llvm::dbgs() << "RHS MaskState: \n";
rhsState.dump();
llvm::dbgs() << "----------------------------------------------\n";
});
LLVM_DEBUG({
andOp.emitError("MaskAnalysis: the add operation have incompatible sizes");
});
return failure();
}
dimInfo newInfo;
newInfo.dimIndex = lIt->dimIndex;
newInfo.shape = minOpFoldResult(lIt->shape, rIt->shape, loc, builder);
if ((isLess(newInfo.shape, lIt->shape) && !lIt->hasBroadCast ||
isLess(newInfo.shape, rIt->shape) && !rIt->hasBroadCast)) {
LLVM_DEBUG({
llvm::dbgs() << "LHS MaskState: \n";
lhsState.dump();
llvm::dbgs() << "RHS MaskState: \n";
rhsState.dump();
llvm::dbgs() << "----------------------------------------------\n";
});
LLVM_DEBUG({
andOp.emitError("MaskAnalysis: the add operation have incompatible sizes."
"Valid dimensions are split.");
});
return failure();
}
newInfo.currentType = lIt->hasBroadCast ? rIt->currentType : lIt->currentType;
if (lIt->currentType != dimInfo::CompareType::deafaultType &&
rIt->currentType != dimInfo::CompareType::deafaultType &&
lIt->currentType != rIt->currentType) {
LLVM_DEBUG({
andOp.emitError("MaskAnalysis: do not suppport different compare mode within"
"the same dimension.");
});
return failure();
}
if (lIt->hasBroadCast) {
newInfo.offset = rIt->offset;
newInfo.rhs = rIt->rhs;
newInfo.hasBroadCast = rIt->hasBroadCast;
} else if (rIt->hasBroadCast) {
newInfo.offset = lIt->offset;
newInfo.rhs = lIt->rhs;
newInfo.hasBroadCast = lIt->hasBroadCast;
} else {
LLVM_DEBUG({
andOp.emitError("MaskAnalysis: do not suppport "
"and in the same dimension.");
});
return failure();
}
newStateInfo.emplace_back(newInfo);
if (isEqual(lIt->shape, newInfo.shape))
++lIt;
else
lIt->shape = divOpFoldResult(lIt->shape, newInfo.shape, loc, builder);
if (isEqual(rIt->shape, newInfo.shape))
++rIt;
else
rIt->shape = divOpFoldResult(rIt->shape, newInfo.shape, loc, builder);
}
while (rIt != rhsState.stateInfo.end()) {
newStateInfo.push_back(*rIt++);
}
while (lIt != lhsState.stateInfo.end()) {
newStateInfo.push_back(*lIt++);
}
this->stateInfo = newStateInfo;
LLVM_DEBUG({
llvm::dbgs() << "After AND MaskState: \n";
this->dump();
llvm::dbgs() << "----------------------------------------------\n";
});
return success();
}
LogicalResult MaskState::analysisMask(Value operand) {
auto op = operand.getDefiningOp();
if (!op) {
return failure();
}
auto loc = op->getLoc();
OpBuilder builder(op);
LLVM_DEBUG({
llvm::dbgs() << "----------------------------------------------\n";
llvm::dbgs() << "Analyzing mask: " << operand << "\n";
});
if (this->parse(operand, loc, builder).failed() || this->isEmpty()) {
return failure();
}
LLVM_DEBUG({
llvm::dbgs() << "Mask analysis result: \n";
this->dump();
llvm::dbgs() << "MaskAnalysis: successfully analyzed mask.\n";
llvm::dbgs() << "----------------------------------------------\n";
});
return success();
}
Value MaskState::createNewMask(const Location loc, OpBuilder &builder) {
if (isEmpty())
return nullptr;
SmallVector<int64_t> shape;
for (auto info : stateInfo) {
auto staticShape = getIntAttr(info.shape);
if (!staticShape.has_value()) {
LLVM_DEBUG({
InFlightDiagnostic diag = emitError(loc)
<< "MaskAnalysis: dynamic shape is not supported in mask "
"generation\n";
});
return nullptr;
}
shape.emplace_back(staticShape.value());
}
SmallVector<Value> cacheResults;
auto maskShape = RankedTensorType::get(shape, builder.getI1Type());
auto createRhsValue = [&](OpFoldResult rhs) -> Value {
if (auto rhsInt = getIntAttr(rhs)) {
auto rhsAttr =
builder.getI32IntegerAttr(static_cast<int32_t>(rhsInt.value()));
return builder.create<arith::ConstantOp>(loc, rhsAttr).getResult();
}
Value rhsValue = dyn_cast<Value>(rhs);
if (rhsValue.getType().isIndex()) {
rhsValue = builder.create<arith::IndexCastOp>(loc, builder.getI32Type(),
rhsValue);
}
return rhsValue;
};
for (size_t i = 0; i < stateInfo.size(); ++i) {
auto info = stateInfo[i];
if (info.hasBroadCast) {
continue;
}
auto indexI32RowType = RankedTensorType::get(shape[i], builder.getI32Type());
Value newMask = builder.create<triton::MakeRangeOp>(loc, indexI32RowType, 0,
shape[i]);
Value newOffset = createRhsValue(info.offset);
if (newOffset.getType().isIndex()) {
newOffset = builder.create<arith::IndexCastOp>(loc, builder.getI32Type(),
newOffset);
}
Value splatRhs = builder.create<triton::SplatOp>(
loc, indexI32RowType, newOffset);
newMask = builder.create<arith::AddIOp>(
loc, newMask,
splatRhs);
auto rhsValue = createRhsValue(info.rhs);
auto splatOp =
builder.create<triton::SplatOp>(loc, indexI32RowType, rhsValue);
if (info.currentType == dimInfo::CompareType::deafaultType) {
LLVM_DEBUG({
InFlightDiagnostic diag = emitError(loc)
<< "MaskAnalysis: cannot generate mask when "
"compare type is not set\n";
});
return nullptr;
}
auto cmpOp = builder.create<arith::CmpIOp>(
loc, info.compareTypeIsLess() ? arith::CmpIPredicate::slt
: arith::CmpIPredicate::sge,
newMask, splatOp.getResult());
auto expandValue = cmpOp.getResult();
for (size_t j = 0; j < stateInfo.size(); ++j) {
if (j == i)
continue;
expandValue = builder.create<triton::ExpandDimsOp>(loc, expandValue, j);
}
auto broadcastValue =
builder.create<triton::BroadcastOp>(loc, maskShape, expandValue);
cacheResults.push_back(broadcastValue);
}
if (cacheResults.empty()) {
LLVM_DEBUG({
InFlightDiagnostic diag = emitWarning(loc)
<< "MaskAnalysis: cannot generate mask when all "
"dimensions are broadcasted";
});
return nullptr;
}
newMask = cacheResults[0];
for (size_t i = 1; i < cacheResults.size(); ++i) {
newMask = builder.create<arith::AndIOp>(loc, newMask, cacheResults[i]);
}
return newMask;
}
}