#ifndef TRITON_ANALYSIS_UTILITY_H
#define TRITON_ANALYSIS_UTILITY_H
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Support/LLVM.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Tools/LinearLayout.h"
namespace mlir {
/// Returns true when `v` is defined by an `arith::ConstantOp` whose value
/// attribute is a splat dense floating-point or dense integer elements
/// attribute with a zero splat value.
///
/// Returns false when `v` has no defining `arith::ConstantOp`, when the
/// attribute is not one of those two dense kinds, or when it is not a splat.
inline bool isZeroConst(Value v) {
  if (auto cst = v.getDefiningOp<arith::ConstantOp>()) {
    auto valueAttr = cst.getValueAttr();
    // Floating-point tensor constant: every element must be the same zero.
    if (auto fpElems = dyn_cast<DenseFPElementsAttr>(valueAttr))
      return fpElems.isSplat() && fpElems.getSplatValue<APFloat>().isZero();
    // Integer tensor constant: same rule.
    if (auto intElems = dyn_cast<DenseIntElementsAttr>(valueAttr))
      return intElems.isSplat() && intElems.getSplatValue<APInt>().isZero();
  }
  return false;
}
/// Helper that pairs a `triton::ReduceOp` with type information read from its
/// first operand (ranked tensor type, shape, encoding) plus the op's element
/// types and reduction axis.
class ReduceOpHelper {
public:
  /// Caches the op, its axis and the type data of operand 0.
  ///
  /// Every type returned by `op.getInputTypes()` is then compared against
  /// operand 0: a differing shape emits "shape mismatch" on the op and a
  /// differing encoding emits "encoding mismatch". Construction continues
  /// after emitting; nothing is thrown or returned.
  explicit ReduceOpHelper(triton::ReduceOp op)
      : op(op.getOperation()),
        srcTy(cast<RankedTensorType>(op.getOperands()[0].getType())),
        srcShape(srcTy.getShape()), srcEncoding(srcTy.getEncoding()),
        srcElementTypes(op.getElementTypes()), axis(op.getAxis()) {
    // Note: inside this constructor `op` names the parameter, not the member.
    for (const auto &inputTy : op.getInputTypes()) {
      const bool sameShape = inputTy.getShape() == srcShape;
      if (!sameShape)
        op.emitError() << "shape mismatch";
      const bool sameEncoding = inputTy.getEncoding() == srcEncoding;
      if (!sameEncoding)
        op.emitError() << "encoding mismatch";
    }
  }

  /// Shape of the first operand.
  ArrayRef<int64_t> getSrcShape() { return srcShape; }

  /// Encoding attribute of the first operand.
  Attribute getSrcLayout() { return srcEncoding; }

  /// The wrapped reduce op.
  triton::ReduceOp getOperation() { return op; }

  // The queries below are only declared in this header; their behavior is
  // determined by their out-of-line definitions.
  unsigned getThreadOffsetOnReductionAxis();
  bool isWarpSynchronous();
  unsigned getInterWarpSizeWithUniqueData();
  unsigned getIntraWarpSizeWithUniqueData();
  SmallVector<unsigned> getScratchRepShape();
  SmallVector<unsigned> getOrderWithAxisAtBeginning();
  unsigned getScratchSizeInBytes();
  bool isReduceWithinCTA();
  bool isAssociative();

private:
  // Declaration order matters: the constructor initializes the members below
  // in this order, and srcShape/srcEncoding are derived from srcTy.
  triton::ReduceOp op;
  RankedTensorType srcTy;
  ArrayRef<int64_t> srcShape;
  Attribute srcEncoding;
  SmallVector<Type> srcElementTypes;
  int axis;
};
/// Helper that pairs a `triton::ScanOp` with layout information read from its
/// first operand: the shape, the original ("legacy") encoding attribute, the
/// linear encoding derived from it, and a dimension order.
class ScanLoweringHelper {
public:
  /// Caches type data of operand 0 and picks `order`.
  ///
  /// `order` comes from the blocked encoding when the operand's encoding is a
  /// `BlockedEncodingAttr`, otherwise from the linear encoding. Each type
  /// returned by `op.getInputTypes()` is then compared against operand 0: a
  /// differing shape emits "shape mismatch" on the op and a differing
  /// encoding emits "encoding mismatch". Construction continues afterwards.
  explicit ScanLoweringHelper(triton::ScanOp op) : scanOp(op) {
    auto operandTy = cast<RankedTensorType>(op.getOperands()[0].getType());
    srcElementTypes = op.getElementTypes();
    srcShape = operandTy.getShape();
    legacyEncoding = operandTy.getEncoding();
    srcEncoding = triton::gpu::toLinearEncoding(operandTy);

    auto blocked = dyn_cast<triton::gpu::BlockedEncodingAttr>(legacyEncoding);
    if (!blocked)
      order = srcEncoding.getOrder();
    else
      order = llvm::to_vector(blocked.getOrder());

    for (const auto &inputTy : op.getInputTypes()) {
      const bool sameShape = inputTy.getShape() == srcShape;
      if (!sameShape)
        op.emitError() << "shape mismatch";
      const bool sameEncoding = inputTy.getEncoding() == legacyEncoding;
      if (!sameEncoding)
        op.emitError() << "encoding mismatch";
    }
  }

  // The queries below are only declared in this header; their behavior is
  // determined by their out-of-line definitions.
  bool isSupported();
  unsigned getAxisNumElementsPerThread();
  unsigned getNonAxisNumElementsPerThread();
  unsigned getNonAxisNumThreadsPerWarp();
  unsigned getNonAxisNumThreadsPerCTA();
  unsigned getAxisNumWarpsWithUniqueData();
  unsigned getAxisNumThreadsPerWarpWithUniqueData();
  unsigned getAxisNumBlocks();
  unsigned getNonAxisNumBlocks();
  unsigned getScratchSizeInBytes();
  unsigned getScratchSizeInElems();
  unsigned getAxisElementStride();
  unsigned getAxisThreadStride();
  unsigned getAxisBlockStride();

  // Thin forwarders to the wrapped scan op.
  Location getLoc() { return scanOp.getLoc(); }
  unsigned getAxis() { return scanOp.getAxis(); }
  bool getReverse() { return scanOp.getReverse(); }
  unsigned getNumOperands() { return scanOp.getNumOperands(); }

  // Accessors for the values cached by the constructor. The vector-returning
  // ones hand back copies.
  triton::gpu::LinearEncodingAttr getEncoding() { return srcEncoding; }
  llvm::ArrayRef<int64_t> getShape() { return srcShape; }
  SmallVector<Type> getElementTypes() { return srcElementTypes; }
  SmallVector<unsigned> getOrder() { return order; }

  // Declared only; defined out of line.
  Region &getCombineOp();

private:
  triton::ScanOp scanOp;
  triton::gpu::LinearEncodingAttr srcEncoding;
  Attribute legacyEncoding;
  llvm::ArrayRef<int64_t> srcShape;
  SmallVector<Type> srcElementTypes;
  SmallVector<unsigned> order;
};
// Helper wrapping a triton::GatherOp together with two ranked tensor types
// (srcTy and dstTy). Every member function is only declared in this header,
// so the behavior notes below are assumptions to confirm against the
// out-of-line definitions.
class GatherLoweringHelper {
public:
// NOTE(review): this single-argument constructor is not `explicit`, so a
// triton::GatherOp converts implicitly to a GatherLoweringHelper.
GatherLoweringHelper(triton::GatherOp gatherOp);
// NOTE(review): presumably the scratch size, in bytes, needed to lower the
// gather — confirm.
unsigned getScratchSizeInBytes();
// NOTE(review): presumably whether the gather can be performed within a
// single warp — confirm.
bool isWarpLocal();
private:
triton::GatherOp gatherOp;
// NOTE(review): presumably the gather's source and result tensor types —
// confirm in the constructor definition.
RankedTensorType srcTy;
RankedTensorType dstTy;
};
// Plain data aggregate returned by getWarpLayoutConvertDecomposition: two
// linear layouts (pReg, pLane), a list of TranspositionInfo records, and an
// integer nPack.
// NOTE(review): nPack has no default member initializer; it is only
// well-defined if the producer always sets it — confirm.
struct DecomposedWarpConversion {
// One entry of mixedTranspositions: a pair of ints plus four 16-bit selector
// values, each with a default.
struct TranspositionInfo {
std::pair<int, int> transposition;
// NOTE(review): the defaults 0x3210 / 0x7654 look like identity byte-select
// patterns (nibbles 0..3 and 4..7) — confirm against the code that consumes
// them.
uint16_t topPreSel = 0x3210;
uint16_t botPreSel = 0x7654;
uint16_t topPostSel = 0x3210;
uint16_t botPostSel = 0x7654;
};
triton::LinearLayout pReg, pLane;
SmallVector<TranspositionInfo> mixedTranspositions;
int nPack;
};
// Declared only; defined out of line. Returns a DecomposedWarpConversion for
// a conversion from srcTy to dstTy at the given bitwidth.
// NOTE(review): presumably `bitwidth` is the element bit width — confirm.
DecomposedWarpConversion
getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
RankedTensorType dstTy, int bitwidth);
// Declared only; defined out of line. Returns a list of (vector, vector)
// pairs computed from a source shape and a destination shape.
// NOTE(review): presumably each pair groups source dimension indices with the
// destination dimension indices they reshape into — confirm.
SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>>
getReshapeDecomposition(ArrayRef<int64_t> srcShape, ArrayRef<int64_t> dstShape);
// All functions in this group are declared only; their behavior is determined
// by their out-of-line definitions. The notes below are assumptions to
// confirm.

// NOTE(review): presumably the number of scratch elements for a buffer of the
// given shape — confirm.
unsigned getNumScratchElements(ArrayRef<unsigned> shape);
// NOTE(review): presumably whether the dot op can use WMMA — confirm.
bool supportWMMA(triton::DotOp op);
// NOTE(review): presumably whether the dot op / value is compatible with the
// given MMA version — confirm.
bool supportMMA(triton::DotOp op, int version);
bool supportMMA(Value value, int version);
// NOTE(review): presumably a reduced linear layout describing the conversion
// from srcTy to dstTy — confirm.
triton::LinearLayout minimalCvtLayout(Type srcTy, Type dstTy);
// NOTE(review): presumably classify what a srcTy -> dstTy layout conversion
// requires (register reordering, warp shuffles, shared memory) — confirm.
bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy);
bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy);
bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy);
// NOTE(review): presumably whether converting between the two layouts should
// go through distributed shared memory — confirm.
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);
// All functions in this group are declared only; their behavior is determined
// by their out-of-line definitions.

// NOTE(review): presumably a topological sort of `toSort` that tolerates
// several root operations — confirm.
SetVector<Operation *>
multiRootTopologicalSort(const SetVector<Operation *> &toSort);
// Both filters default to nullptr.
// NOTE(review): presumably computes the combined backward/forward slice of
// `op`, restricted by the filters when they are given — confirm.
SetVector<Operation *>
multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr,
TransitiveFilter forwardFilter = nullptr);
// Returns an owning pointer to a DataFlowSolver.
// NOTE(review): which analyses are preloaded, if any, is decided by the
// definition — confirm.
std::unique_ptr<DataFlowSolver> createDataFlowSolver();
/// Call graph over the functions of a module, with one payload of type `T`
/// stored per function.
///
/// Nodes are `FunctionOpInterface` ops. An edge is a (call op, callee) pair
/// stored under the function that contains the call. Functions that are never
/// the resolved target of a call are recorded as roots.
template <typename T> class CallGraph {
public:
  using FuncDataMapT = DenseMap<FunctionOpInterface, T>;

  /// Builds the graph for `moduleOp` immediately.
  explicit CallGraph(ModuleOp moduleOp) : moduleOp(moduleOp) { build(); }

  /// Depth-first traversal starting from every root.
  ///
  /// `updateNodeFn(func)` is invoked for each function visit and
  /// `updateEdgeFn(callOp, callee)` for each edge. The template parameters
  /// choose whether each callback runs before (PreOrder) or after (PostOrder)
  /// descending. A function reachable along several call paths is visited
  /// once per path. A cycle on the current path is a fatal error.
  template <WalkOrder UpdateEdgeOrder = WalkOrder::PreOrder,
            WalkOrder UpdateNodeOrder = WalkOrder::PreOrder,
            typename UpdateEdgeFn, typename UpdateNodeFn>
  void walk(UpdateEdgeFn updateEdgeFn, UpdateNodeFn updateNodeFn) {
    // Holds the functions on the current DFS path, not all functions seen.
    DenseSet<FunctionOpInterface> visited;
    for (auto root : roots) {
      doWalk<UpdateEdgeOrder, UpdateNodeOrder>(root, visited, updateEdgeFn,
                                               updateNodeFn);
    }
  }

  /// Returns a pointer to the payload stored for `funcOp`, or nullptr when no
  /// payload has been stored. The pointer is invalidated by later insertions
  /// into the payload map.
  T *getFuncData(FunctionOpInterface funcOp) {
    // Single lookup instead of count() followed by operator[].
    auto it = funcMap.find(funcOp);
    if (it == funcMap.end())
      return nullptr;
    return &it->second;
  }

  ModuleOp getModuleOp() const { return moduleOp; }

  /// Returns a copy of the root list.
  SmallVector<FunctionOpInterface> getRoots() const { return roots; }

  /// Number of functions that have a payload entry.
  size_t getNumFunctions() const { return funcMap.size(); }

  bool isRoot(FunctionOpInterface funcOp) const {
    return llvm::is_contained(roots, funcOp);
  }

  /// Makes `targetFuncOp` stand in for `funcOp`: edges pointing at `funcOp`
  /// are retargeted, `funcOp`'s outgoing edges and payload are copied to
  /// `targetFuncOp`, and the first matching root entry is replaced. The
  /// entries stored under `funcOp` itself are left in place.
  template <typename FROM, typename TO>
  void mapFuncOp(FROM funcOp, TO targetFuncOp) {
    // Retarget every edge whose callee is the old function.
    for (auto &kv : graph) {
      for (auto &edge : kv.second) {
        if (edge.second == funcOp) {
          edge.second = targetFuncOp;
        }
      }
    }
    // Copy before inserting the new key: DenseMap::operator[] may grow the
    // table, which invalidates references obtained from an earlier
    // operator[] call, so `graph[target] = graph[func]` is not safe.
    auto outgoingEdges = graph[funcOp];
    graph[targetFuncOp] = outgoingEdges;
    // Replace the first occurrence in the root list.
    for (auto it = roots.begin(); it != roots.end(); ++it) {
      if (*it == funcOp) {
        *it = targetFuncOp;
        break;
      }
    }
    // Same copy-then-insert rule for the payload map.
    T funcData = funcMap[funcOp];
    funcMap[targetFuncOp] = funcData;
  }

  /// Replaces `callOp` with `targetCallOp` in every edge that uses it.
  template <typename FROM, typename TO>
  void mapCallOp(FROM callOp, TO targetCallOp) {
    for (auto &kv : graph) {
      for (auto &edge : kv.second) {
        if (edge.first == callOp) {
          edge.first = targetCallOp;
        }
      }
    }
  }

private:
  /// Populates `graph` and `roots` from `moduleOp`.
  void build() {
    SymbolTableCollection symbolTable;
    // Functions that are the resolved callee of at least one call.
    DenseSet<FunctionOpInterface> visited;
    moduleOp.walk([&](Operation *op) {
      auto caller = op->getParentOfType<FunctionOpInterface>();
      if (auto callOp = dyn_cast<CallOpInterface>(op)) {
        auto *callee = callOp.resolveCallableInTable(&symbolTable);
        auto funcOp = dyn_cast_or_null<FunctionOpInterface>(callee);
        // Calls that do not resolve to a function contribute no edge.
        if (funcOp) {
          graph[caller].emplace_back(
              std::pair<CallOpInterface, FunctionOpInterface>(callOp, funcOp));
          visited.insert(funcOp);
        }
      }
    });
    // Anything never called is a root.
    moduleOp.walk([&](FunctionOpInterface funcOp) {
      if (!visited.count(funcOp)) {
        roots.push_back(funcOp);
      }
    });
  }

  /// Recursive step of walk(). `visited` holds the functions on the current
  /// DFS path; re-entering one of them means the call graph has a cycle.
  template <WalkOrder UpdateEdgeOrder = WalkOrder::PreOrder,
            WalkOrder UpdateNodeOrder = WalkOrder::PreOrder,
            typename UpdateEdgeFn, typename UpdateNodeFn>
  void doWalk(FunctionOpInterface funcOp,
              DenseSet<FunctionOpInterface> &visited, UpdateEdgeFn updateEdgeFn,
              UpdateNodeFn updateNodeFn) {
    // Record the function on entry so the cycle check can actually fire;
    // without the insert a recursive call chain would recurse without bound.
    if (!visited.insert(funcOp).second) {
      llvm::report_fatal_error("Cycle detected in call graph");
    }
    if constexpr (UpdateNodeOrder == WalkOrder::PreOrder) {
      updateNodeFn(funcOp);
    }
    // Iterate over a private copy of the edge list. lookup() does not insert,
    // and holding a reference into `graph` across the recursion would dangle
    // if the map grows in the meantime.
    auto edges = graph.lookup(funcOp);
    for (auto [callOp, callee] : edges) {
      if constexpr (UpdateEdgeOrder == WalkOrder::PreOrder) {
        updateEdgeFn(callOp, callee);
      }
      doWalk<UpdateEdgeOrder, UpdateNodeOrder>(callee, visited, updateEdgeFn,
                                               updateNodeFn);
      if constexpr (UpdateEdgeOrder == WalkOrder::PostOrder) {
        updateEdgeFn(callOp, callee);
      }
    }
    if constexpr (UpdateNodeOrder == WalkOrder::PostOrder) {
      updateNodeFn(funcOp);
    }
    // Leaving this function: it is no longer on the current path.
    visited.erase(funcOp);
  }

protected:
  ModuleOp moduleOp;
  // Caller -> list of (call op, callee) edges.
  DenseMap<FunctionOpInterface,
           SmallVector<std::pair<CallOpInterface, FunctionOpInterface>>>
      graph;
  // Per-function payload.
  FuncDataMapT funcMap;
  // Functions that are never a resolved callee.
  SmallVector<FunctionOpInterface> roots;
};
// NOTE(review): createDataFlowSolver is already declared earlier in this
// header with the same signature; this redeclaration is redundant but
// harmless.
std::unique_ptr<DataFlowSolver> createDataFlowSolver();
// Declared only; defined out of line.
// NOTE(review): presumably whether a conversion between the two linear
// layouts only needs warp-level synchronization — confirm.
bool isCvtWarpSync(const triton::LinearLayout &srcLayout,
const triton::LinearLayout &dstLayout);
}
#endif