* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "ascend/include/TritonToLinalg/UseAnalysis.h"
#include "ascend/include/Utils/Utils.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "bishengir/Dialect/HIVM/IR/HIVM.h"
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace triton;
using namespace dataflow;
#define DEBUG_TYPE "triton-use-analysis"
/// Maps a UseType to its printable name.
///
/// \param useTy the use classification to print.
/// \returns "MetaUse", "DataUse", "MixUse" or "Undefined"; an empty string
///          when `useTy` equals none of those four enumerators.
std::string stringifyUseType(UseType useTy) {
  if (useTy == UseType::MetaUse)
    return "MetaUse";
  if (useTy == UseType::DataUse)
    return "DataUse";
  if (useTy == UseType::MixUse)
    return "MixUse";
  if (useTy == UseType::Undefined)
    return "Undefined";
  // No enumerator matched: hand back an empty name.
  return std::string();
}
/// Transfer function of the use analysis for a single operation.
///
/// `operands` and `results` hold the UseInfo lattice elements that belong to
/// the operands / results of `op` (indexed the same way as the op's operand
/// and result lists — this is how every case below addresses them).
///
/// Rules applied, in order:
///   * If `op` has exactly one result and that result is a shaped type whose
///     element type is a triton pointer, every operand is marked MetaUse.
///   * Memory / side-effect ops mark their address-like operands (pointer,
///     offsets, mask, `other`) as MetaUse and their payload operands as
///     DataUse (see the individual cases for the exact mapping).
///   * Every other op forwards the use of its results to all of its operands.
///
/// With LLVM >= 20 the function returns success(); older versions return void.
#if LLVM_VERSION_MAJOR >= 20
LogicalResult
triton::UseAnalysis::visitOperation(Operation *op, ArrayRef<UseInfo *> operands,
                                    ArrayRef<const UseInfo *> results) {
#else
void triton::UseAnalysis::visitOperation(Operation *op,
                                         ArrayRef<UseInfo *> operands,
                                         ArrayRef<const UseInfo *> results) {
#endif
  // An op producing a tensor of pointers only does address computation, so
  // everything feeding it is a meta use.
  if (op->getResults().size() == 1) {
    auto resultType = dyn_cast<ShapedType>(op->getResult(0).getType());
    if (resultType && isa<triton::PointerType>(resultType.getElementType())) {
      for (auto opnd : operands) {
        propagateUse(opnd, UseType::MetaUse);
      }
    }
  }

  TypeSwitch<Operation *>(op)
      .Case<triton::LoadOp>([&](auto load) {
        // Operand layout: ptr, [mask], [other].
        propagateUse(operands[0], UseType::MetaUse);
        auto mask = load.getMask();
        auto other = load.getOther();
        if (mask) {
          assert(mask != other && "mask and other cannot be the same");
          propagateUse(operands[1], UseType::MetaUse);
        }
        if (other) {
          // Optional operands are packed: without a mask, `other` is operand
          // 1, not 2. Indexing 2 unconditionally would run past the end of
          // `operands` for a load that has `other` but no mask.
          propagateUse(operands[mask ? 2 : 1], UseType::MetaUse);
        }
      })
      .Case<triton::PrintOp>(
          [&](auto) { propagateUse(operands[0], UseType::DataUse); })
      .Case<triton::AssertOp>(
          [&](auto) { propagateUse(operands[0], UseType::DataUse); })
      .Case<triton::StoreOp>([&](auto store) {
        // Operand layout: ptr, value, [mask].
        propagateUse(operands[0], UseType::MetaUse);
        propagateUse(operands[1], UseType::DataUse);
        auto mask = store.getMask();
        if (mask) {
          assert(mask != store.getValue() &&
                 "mask and data cannot be the same");
          propagateUse(operands[2], UseType::MetaUse);
        }
      })
      .Case<triton::ascend::IndirectStoreOp>([&](auto store) {
        // Operands 0 and 1 are treated as addressing (MetaUse), operand 2 as
        // the stored data, and an optional mask follows as operand 3.
        propagateUse(operands[0], UseType::MetaUse);
        propagateUse(operands[1], UseType::MetaUse);
        propagateUse(operands[2], UseType::DataUse);
        auto mask = store.getMask();
        if (mask) {
          assert(mask != store.getValue() &&
                 "mask and data cannot be the same");
          propagateUse(operands[3], UseType::MetaUse);
        }
      })
      .Case<triton::AtomicRMWOp>([&](auto atomicOp) {
        // The pointer of an atomic RMW is marked MixUse, the value DataUse and
        // the optional mask MetaUse.
        propagateUse(operands[0], UseType::MixUse);
        propagateUse(operands[1], UseType::DataUse);
        auto mask = atomicOp.getMask();
        if (mask) {
          assert(mask != atomicOp.getVal() &&
                 "mask and data cannot be the same");
          propagateUse(operands[2], UseType::MetaUse);
        }
      })
      .Case<triton::AtomicCASOp>([&](auto) {
        // Operand 0 is the pointer; operands 1 and 2 are both data.
        propagateUse(operands[0], UseType::MetaUse);
        propagateUse(operands[1], UseType::DataUse);
        propagateUse(operands[2], UseType::DataUse);
      })
      .Case<triton::DotOp>([&](auto dot) {
        // A and B inherit whatever use the dot result has.
        propagateResults(operands[0], results);
        propagateResults(operands[1], results);

        // The accumulator is a meta use only when it is a splat of an
        // arith.constant; any other accumulator is real data.
        auto opc = dot.getC();
        triton::SplatOp splat;
        if (opc) {
          splat = opc.template getDefiningOp<triton::SplatOp>();
        }
        if (opc && splat && splat.getSrc().getDefiningOp<arith::ConstantOp>()) {
          propagateUse(operands[2], UseType::MetaUse);
        } else {
          propagateUse(operands[2], UseType::DataUse);
        }
      })
      .Case<LoopLikeOpInterface>([&](auto loopOp) {
        // Each loop result pushes its use to both the matching yielded value
        // and the matching init value.
        for (const auto &[yield, init, result]: llvm::zip_equal(loopOp.getYieldedValues(), loopOp.getInits(), results)) {
          propagateResults(getLatticeElement(yield), {result});
          propagateResults(getLatticeElement(init), {result});
        }
      })
      .Case<triton::ReduceOp>([&](auto) {
        for (auto operand : operands) {
          propagateUse(operand, UseType::DataUse);
        }
      })
      .Case<hivm::FixpipeOp>(
          [&](auto) { propagateUse(operands[0], UseType::DataUse); })
      .Case<hivm::CopyOp>(
          [&](auto) { propagateUse(operands[0], UseType::DataUse); })
      .Default([&](Operation *) {
        // Generic op: operands are used the same way the results are.
        for (auto operand : operands) {
          propagateResults(operand, results);
        }
      });
#if LLVM_VERSION_MAJOR >= 20
  return success();
#endif
}
/// Walks backwards along the operand chain starting at `rootOp` and strips
/// the "MetaUse" attribute from the ops selected by the predicates below
/// (the traversal itself is done by traverseBackwardUpdateOperandChainIf).
///
/// \param rootOp    operation the backward traversal starts from.
/// \param applyRoot when false, `rootOp` itself is never selected for update.
void setMixUseRecursively(Operation *rootOp, bool applyRoot = true) {
  // Selects an op for update: it must be tagged MetaUse, must not produce a
  // ranked tensor of triton pointers, and must not be the root unless
  // `applyRoot` is set.
  auto shouldUpdate = [rootOp, applyRoot](Operation *candidate) {
    for (auto produced : candidate->getResults()) {
      auto rankedTy = dyn_cast<RankedTensorType>(produced.getType());
      if (!rankedTy)
        continue;
      if (isa<triton::PointerType>(rankedTy.getElementType()))
        return false;
    }
    if (!isMetaUse(candidate))
      return false;
    return candidate != rootOp || applyRoot;
  };

  // Second predicate handed to the traversal: true for any triton load other
  // than the root.
  auto isForeignLoad = [rootOp](Operation *candidate) {
    return isa<triton::LoadOp>(candidate) && candidate != rootOp;
  };

  // Update action: drop "MetaUse"; the "MixUse" marker is only attached
  // inside LLVM_DEBUG.
  auto demote = [](OpBuilder &builder, Operation *selected) {
    LLVM_DEBUG({ selected->setAttr("MixUse", UnitAttr::get(builder.getContext())); });
    selected->removeAttr("MetaUse");
  };

  traverseBackwardUpdateOperandChainIf(rootOp, shouldUpdate, isForeignLoad,
                                       demote);
}
/// Applies setMixUseRecursively to whatever produces `v`.
///
/// An op result is handled through its defining op. A block argument is only
/// handled when its parent op is loop-like: the defining ops of the tied init
/// operand and of the tied yielded value (when they exist) are processed.
/// Any other value is ignored.
static void setMixUseFromValue(Value v) {
  auto *producer = v.getDefiningOp();
  if (producer) {
    setMixUseRecursively(producer);
    return;
  }

  auto regionArg = dyn_cast<BlockArgument>(v);
  if (!regionArg)
    return;

  auto enclosingLoop = dyn_cast_or_null<LoopLikeOpInterface>(
      regionArg.getOwner()->getParentOp());
  if (!enclosingLoop)
    return;

  // Value entering the loop for this iter-arg.
  OpOperand *tiedInit = enclosingLoop.getTiedLoopInit(regionArg);
  if (tiedInit) {
    auto *initProducer = tiedInit->get().getDefiningOp();
    if (initProducer)
      setMixUseRecursively(initProducer);
  }

  // Value fed back into this iter-arg by the loop body.
  OpOperand *tiedYield = enclosingLoop.getTiedLoopYieldedValue(regionArg);
  if (tiedYield) {
    auto *yieldProducer = tiedYield->get().getDefiningOp();
    if (yieldProducer)
      setMixUseRecursively(yieldProducer);
  }
}
/// Recursively inspects the definition chain of `v` looking for `target`.
///
/// \returns true    if `v` is recorded as DataUse in `solver`, or is defined
///                  by a loop-like op or an scf.if;
///          false   if `v` is `target` itself;
///          nullopt if `v` has no defining op, or none of the defining op's
///                  operands yields an answer;
///          otherwise the answer of the first operand that yields one, OR-ed
///          with "the defining op is not MetaUse".
std::optional<bool> isIterArgMixUse(Value v, Value target, const DataFlowSolver &solver) {
  auto producer = v.getDefiningOp();
  auto *lattice = solver.lookupState<UseInfo>(v);

  const bool recordedAsData = lattice && lattice->type == UseType::DataUse;
  if (recordedAsData ||
      isa_and_nonnull<LoopLikeOpInterface, scf::IfOp>(producer))
    return true;

  if (v == target)
    return false;

  if (!producer)
    return std::nullopt;

  for (auto input : producer->getOperands()) {
    auto verdict = isIterArgMixUse(input, target, solver);
    if (!verdict.has_value())
      continue;
    // First operand that reaches a decision settles the result.
    return verdict.value() || !isMetaUse(producer);
  }
  return std::nullopt;
}
/// Post-processing for scf.while.
///
/// Pass 1: for every (loop result, scf.condition argument) pair, if the
/// result is recorded as DataUse the argument's defining op is handed to
/// setMixUseRecursively.
/// Pass 2: for every (yield operand, "before" region argument) pair, the
/// yield operand's defining op is handed to setMixUseRecursively when
/// isIterArgMixUse reports true for that pair.
/// Values without a defining op are skipped in both passes.
void postProcessWhileOp(scf::WhileOp op, const DataFlowSolver &solver) {
  for (const auto &[loopResult, forwarded] :
       llvm::zip_equal(op->getResults(), op.getConditionOp().getArgs())) {
    auto *producer = forwarded.getDefiningOp();
    if (producer) {
      auto *lattice = solver.lookupState<UseInfo>(loopResult);
      const bool resultIsData = lattice && lattice->type == UseType::DataUse;
      if (resultIsData)
        setMixUseRecursively(producer);
    }
  }

  for (const auto &[yielded, beforeArg] :
       llvm::zip_equal(op.getYieldOp().getOperands(), op.getBeforeArguments())) {
    auto *producer = yielded.getDefiningOp();
    if (producer) {
      // An undecided (nullopt) answer counts as "not mixed".
      const bool mixed =
          isIterArgMixUse(yielded, beforeArg, solver).value_or(false);
      if (mixed)
        setMixUseRecursively(producer);
    }
  }
}
/// Post-processing for loop-like ops.
///
/// scf.while is delegated to postProcessWhileOp. For any other loop, each
/// (result, yielded value, region iter-arg) triple is inspected: when the
/// yielded value has a defining op and either the result is recorded as
/// DataUse or isIterArgMixUse reports true, that defining op is handed to
/// setMixUseRecursively.
void postProcessLoopOp(LoopLikeOpInterface loopOp, const DataFlowSolver &solver) {
  auto whileOp = dyn_cast<scf::WhileOp>(loopOp.getOperation());
  if (whileOp) {
    postProcessWhileOp(whileOp, solver);
    return;
  }

  for (const auto &[loopResult, yielded, iterArg] :
       llvm::zip_equal(loopOp->getResults(), loopOp.getYieldedValues(),
                       loopOp.getRegionIterArgs())) {
    auto *producer = yielded.getDefiningOp();
    if (producer) {
      auto *lattice = solver.lookupState<UseInfo>(loopResult);
      bool needsMix = lattice && lattice->type == UseType::DataUse;
      // Only consult the iter-arg chain when the result is not already data.
      if (!needsMix)
        needsMix = isIterArgMixUse(yielded, iterArg, solver).value_or(false);
      if (needsMix)
        setMixUseRecursively(producer);
    }
  }
}
/// Runs the use analysis on `funcOp` and rewrites / annotates its operations.
///
/// Steps, in order:
///   1. Solve DeadCodeAnalysis + SparseConstantPropagation + UseAnalysis.
///   2. Walk every op, merge the UseInfo of its results into one UseType and
///      tag the op. A MixUse op whose results are all shaped and that feeds
///      meta consumers is cloned; the clone is tagged "MetaUse" and the meta
///      consumers are redirected to it.
///   3. Walk again and demote meta chains around indirect load / calc / store
///      ops, ops carrying the discrete attribute, loops, scf.if yields and
///      atomic RMW masks carrying the discrete-mask attribute.
///   4. Remove "MetaUse" from ops that are also MixUse and from every
///      hivm::CustomOp.
///
/// Note: the "Undefined" and "MixUse" markers, and "DataUse" except for the
/// forced atomic case, are attached inside LLVM_DEBUG only, so they are not
/// present unless debug output for this DEBUG_TYPE is active.
///
/// \returns failure() when the data-flow solver fails, success() otherwise.
LogicalResult triton::runUseAnalysis(triton::FuncOp &funcOp) {
  MLIRContext *context = funcOp.getContext();
  SymbolTableCollection symbolTable;

  DataFlowSolver solver;
  solver.load<DeadCodeAnalysis>();
  solver.load<SparseConstantPropagation>();
  solver.load<UseAnalysis>(symbolTable);
  if (failed(solver.initializeAndRun(funcOp))) {
    return failure();
  }

  // Only referenced from LLVM_DEBUG blocks, hence possibly unused.
  [[maybe_unused]] auto &os = llvm::dbgs();

  // ---- Pass 1: classify every op and split MixUse producers. -------------
  funcOp.walk([&](Operation *op) {
    LLVM_DEBUG({ os << "[UseAnalysis] op is " << *op << "\n"; });

    // Merge the use type of all results: Undefined results are ignored, and
    // any disagreement (or an explicit MixUse) collapses to MixUse.
    UseType useType = UseType::Undefined;
    for (auto result : op->getResults()) {
      LLVM_DEBUG({ os << "[UseAnalysis] ===> result is " << result << "\n"; });
      auto use = solver.lookupState<UseInfo>(result);
      assert(use && "Lattice value not found");
      auto thisUseType = use->type;
      LLVM_DEBUG({
        os << "[UseAnalysis] ==========> useType is "
           << stringifyUseType(thisUseType) << "\n";
      });
      if (thisUseType == UseType::Undefined) {
        continue;
      }
      if (useType == UseType::Undefined) {
        useType = thisUseType;
      }
      if (thisUseType == UseType::MixUse || thisUseType != useType) {
        useType = UseType::MixUse;
        break;
      }
    }

    if (useType == UseType::Undefined) {
      LLVM_DEBUG({ op->setAttr("Undefined", UnitAttr::get(context)); });
      return;
    } else if (useType == UseType::MetaUse) {
      auto memEffect = dyn_cast<MemoryEffectOpInterface>(op);
      if (memEffect) {
        // Atomics must never be treated as meta-only; tag them as DataUse
        // unconditionally (this attribute is set outside LLVM_DEBUG).
        if (isa<triton::AtomicRMWOp, triton::AtomicCASOp>(op)) {
          LLVM_DEBUG({
            os << "force protecting side-effect op:" << *op <<"\n";
          });
          op->setAttr("DataUse", UnitAttr::get(context));
          return;
        }
      }
      if (!isa<mlir::scf::IfOp, mlir::scf::ForOp, mlir::scf::WhileOp, triton::ReduceOp>(op)) {
        assert(op->getNumResults() == 1 &&
               "Ops used for meta computation are expected to have one result");
      }
      // Tag the op when any result is shaped, or it is a discrete load, or it
      // is a bitcast producing a pointer.
      for (auto result : op->getResults()) {
        if (isa<ShapedType>(result.getType()) ||
            (isa<triton::LoadOp>(op) &&
             op->hasAttr(ConverterUtils::discreteAttrName)) ||
            (isa<triton::BitcastOp>(op) &&
             isa<PointerType>(result.getType()))) {
          op->setAttr("MetaUse", UnitAttr::get(context));
        }
      }
      return;
    } else if (useType == UseType::DataUse) {
      LLVM_DEBUG({ op->setAttr("DataUse", UnitAttr::get(context)); });
      return;
    }

    assert(useType == UseType::MixUse);

    // Only ops whose results are all shaped are candidates for splitting;
    // loops, scf.if and arith.select are never split. (This early return is
    // also what keeps loop-like ops and scf.if away from the clone below.)
    bool shapedResult = true;
    for (auto result : op->getResults())
      shapedResult &= isa<ShapedType>(result.getType());
    if (!shapedResult || isa<LoopLikeOpInterface, scf::IfOp, arith::SelectOp>(op)) {
      LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); });
      return;
    }

    // Collect the users that consume a result of `op` as meta information.
    llvm::SetVector<Operation *> metaUsers;
    for (auto result : op->getResults()) {
      for (auto user : result.getUsers()) {
        TypeSwitch<Operation *>(user)
            .Case<triton::LoadOp>([&](auto load) {
              auto ptr = load.getPtr();
              auto mask = load.getMask();
              auto other = load.getOther();
              if (result == ptr || result == mask || result == other) {
                metaUsers.insert(user);
              }
            })
            .Case<triton::StoreOp>([&](auto store) {
              auto ptr = store.getPtr();
              auto mask = store.getMask();
              if (result == ptr || result == mask) {
                metaUsers.insert(user);
              }
            })
            .Case<triton::ascend::IndirectStoreOp>([&](auto indirectstore) {
              auto src = indirectstore.getSrc();
              auto offset = indirectstore.getOffsets();
              auto mask = indirectstore.getMask();
              if (result == src || result == offset ||
                  result == mask) {
                metaUsers.insert(user);
              }
            })
            .Case<triton::AtomicRMWOp>([&](auto atomicOp) {
              auto ptr = atomicOp.getPtr();
              auto mask = atomicOp.getMask();
              if (result == ptr || result == mask)
                metaUsers.insert(user);
            })
            .Case<triton::AtomicCASOp>([&](auto atomicOp) {
              auto ptr = atomicOp.getPtr();
              if (result == ptr)
                metaUsers.insert(user);
            })
            .Case<triton::DotOp>([&](auto dot) {
              // A dot counts as a meta user when its accumulator is a splat
              // of an arith.constant.
              auto opc = dot.getC();
              triton::SplatOp splat;
              if (opc) {
                splat = opc.template getDefiningOp<triton::SplatOp>();
              }
              if (opc && splat &&
                  splat.getSrc().getDefiningOp<arith::ConstantOp>()) {
                metaUsers.insert(user);
              }
            })
            .Case<triton::PrintOp>([&](auto print) {
              // Printing is never a meta use.
            })
            .Default([&](Operation *userOp) {
              // Any other user is a meta user iff all of its results are
              // MetaUse. A result without lattice state is conservatively
              // treated as non-meta instead of being dereferenced.
              bool allMeta = true;
              for (auto res : userOp->getResults()) {
                auto resUse = solver.lookupState<UseInfo>(res);
                if (!resUse || resUse->type != UseType::MetaUse) {
                  allMeta = false;
                  break;
                }
              }
              if (allMeta) {
                metaUsers.insert(user);
              }
            });
      }
    }

    if (metaUsers.empty()) {
      LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); });
      return;
    }
    // Loads are left untouched (neither cloned nor tagged).
    if (isa<triton::LoadOp>(op))
      return;

    // Split: keep `op` for data consumers, and give the meta consumers a
    // clone tagged "MetaUse", inserted right before `op`.
    OpBuilder builder(op);
    auto clone = builder.clone(*op);
    LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); });
    clone->setAttr("MetaUse", UnitAttr::get(context));
    for (auto [res_i, result] : llvm::enumerate(op->getResults())) {
      for (auto user : metaUsers) {
        for (auto &operand : user->getOpOperands()) {
          if (operand.get() == result) {
            operand.set(clone->getResult(res_i));
          }
        }
      }
    }
  });
  LLVM_DEBUG({
    os << "[UseAnalysis] Before post-process, funcOp is " << *funcOp << "\n";
  });

  // ---- Pass 2: post-process special ops. ----------------------------------
  funcOp.walk([&](Operation *op) {
    if (opIsIndirectLoad(op) || opIsIndirectCalc(op) || isa<triton::ascend::IndirectStoreOp>(op)) {
      LLVM_DEBUG({
        os << "[UseAnalysis] Found indirect load interface op: " << *op << "\n";
      });
      // Forward: follow MetaUse users of `op` and demote each visited op's
      // backward chain; ops matching the second predicate (non-meta load /
      // store / indirect store) are collected in `stopOps`.
      llvm::SmallPtrSet<Operation *, 16> stopOps;
      traverseForwardUpdateUserChainIf(
          op,
          [op](Operation *curOp) { return isMetaUse(curOp) && curOp != op; },
          [&](Operation *curOp) {
            return (isa<triton::LoadOp>(curOp) || isa<triton::StoreOp>(curOp) || isa<triton::ascend::IndirectStoreOp>(curOp))
                   && !isMetaUse(curOp);
          },
          [](OpBuilder &b, Operation *op) {
            setMixUseRecursively(op);
          },
          stopOps);
      LLVM_DEBUG({
        os << "[UseAnalysis] stopOps are \n";
        for (auto [idx, stopOp] : llvm::enumerate(stopOps))
          os << idx << ": " << *stopOp << "\n";
      });
      LLVM_DEBUG({
        os << "[UseAnalysis] After trace, funcOp is " << *funcOp << "\n";
      });
      // Backward from every stop op, without touching the stop op itself.
      for (auto *stopOp : stopOps)
        setMixUseRecursively(stopOp, false);
      LLVM_DEBUG({
        os << "[UseAnalysis] After traceback of stopOp, funcOp is " << *funcOp
           << "\n";
      });
      LLVM_DEBUG({ op->setAttr("MixUse", UnitAttr::get(context)); });
      op->removeAttr("MetaUse");
    }
    if (op->hasAttr(ConverterUtils::discreteAttrName))
      setMixUseRecursively(op);
    if (auto loopOp = dyn_cast<LoopLikeOpInterface>(op)) {
      postProcessLoopOp(loopOp, solver);
    } else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
      // Everything yielded from either branch is demoted.
      SmallVector<Value> yields(ifOp.thenYield().getOperands());
      if (!ifOp.getElseRegion().empty())
        yields.append(llvm::to_vector(ifOp.elseYield().getOperands()));
      for (auto yield : yields) {
        setMixUseFromValue(yield);
      }
    } else if (auto atomicRmwOp = dyn_cast<triton::AtomicRMWOp>(op)) {
      auto mask = atomicRmwOp.getMask();
      if (mask && op->hasAttr(ConverterUtils::discreteMaskAttrName)) {
        // The mask may be a block argument, in which case it has no defining
        // op; guard like every other setMixUseRecursively call site does.
        if (auto *maskDefOp = mask.getDefiningOp())
          setMixUseRecursively(maskDefOp);
      }
    }
  });

  // ---- Pass 3: resolve conflicting tags. ----------------------------------
  funcOp.walk([&](Operation *op) {
    if (isMetaUse(op) && isMixUse(op)) {
      op->removeAttr("MetaUse");
    }
  });
  funcOp.walk([&](hivm::CustomOp op) {
    if (isMetaUse(op)) {
      op->removeAttr("MetaUse");
    }
  });
  LLVM_DEBUG({
    os << "[UseAnalysis] After post-process, funcOp is " << *funcOp << "\n";
  });
  return success();
}
/// Builds a pattern that is tried on operations of any type
/// (MatchAnyOpTypeTag); the second base-class argument, 10, is the value
/// passed as the pattern's benefit.
///
/// \param context MLIR context the pattern is registered in.
MetaUseEraser::MetaUseEraser(MLIRContext *context)
: RewritePattern(MatchAnyOpTypeTag(), 10, context) {}
/// Erases `op` when it is tagged as a meta use.
///
/// triton::AddPtrOp is always rejected here (its message states it is handled
/// separately); any op for which isMetaUse() is false is rejected as well.
///
/// \returns success() after erasing a meta-use op, otherwise a match failure.
LogicalResult MetaUseEraser::matchAndRewrite(Operation *op,
                                             PatternRewriter &rewriter) const {
  // Debug aid: total number of uses across all results of `op`.
  LLVM_DEBUG({
    int64_t totalUses = 0;
    for (auto produced : op->getResults()) {
      totalUses += std::distance(produced.use_begin(), produced.use_end());
    }
    llvm::dbgs() << "Number of user: " << totalUses << "\n";
  });

  if (isa<triton::AddPtrOp>(op)) {
    return rewriter.notifyMatchFailure(op,
                                       "AddPtrOp will be handled separately");
  }

  if (!isMetaUse(op)) {
    return rewriter.notifyMatchFailure(op, "requires meta ops");
  }

  rewriter.eraseOp(op);
  return success();
}