#include "triton/Analysis/Utility.h"
#include <deque>
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/MLIRTypes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Tools/LayoutUtils.h"
#include "triton/Tools/LinearLayout.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/ADT/SmallSet.h"
namespace mlir {
using namespace triton;
using namespace triton::gpu;
/// Returns the dimension order of the source tensor's linear encoding with the
/// reduction axis moved to the front. The relative order of all other
/// dimensions is left unchanged.
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
  auto dimOrder = toLinearEncoding(srcTy).getOrder();
  auto axisPos = std::find(dimOrder.begin(), dimOrder.end(), axis);
  // Rotating the prefix [begin, axisPos] so that `axis` comes first is the
  // same as erasing `axis` and re-inserting it at index 0.
  std::rotate(dimOrder.begin(), axisPos, axisPos + 1);
  return dimOrder;
}
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
auto *ctx = srcEncoding.getContext();
auto linearLayout = toLinearLayout(srcTy);
auto kLane = mlir::StringAttr::get(ctx, "lane");
const auto &bases = linearLayout.getBases();
const auto &lanes = bases.find(kLane)->second;
auto offset = 1;
for (const auto &lane : lanes) {
if (lane[axis] != 0)
break;
offset *= 2;
}
return offset;
}
/// Decides whether converting from `srcLayout` to `dstLayout` should go
/// through distributed shared memory.
///
/// Returns false when there is a single CTA or when both layouts have the same
/// CTA layout. Returns true when the destination is a slice layout whose
/// sliced dimension has more than one CTA in the parent, or when the CTA
/// layouts differ. A source slice layout whose sliced dimension has more than
/// one CTA in the parent is reported as a fatal error.
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) {
  const unsigned ctaCount = getNumCTAs(srcLayout);
  assert(ctaCount == getNumCTAs(dstLayout) &&
         "Invalid layout conversion: the numbers of CTAs of src and dst "
         "layouts are different");

  // A single CTA never needs distributed shared memory.
  if (ctaCount == 1)
    return false;

  // True iff `layout` is a slice layout and the parent has more than one CTA
  // along the sliced dimension.
  auto slicedDimSpansCTAs = [](Attribute layout) {
    auto slice = mlir::dyn_cast<SliceEncodingAttr>(layout);
    if (!slice)
      return false;
    auto ctasPerCGA = getCTAsPerCGA(slice.getParent());
    return ctasPerCGA[slice.getDim()] != 1;
  };

  if (slicedDimSpansCTAs(srcLayout))
    llvm::report_fatal_error("Layout conversion to be implemented");
  if (slicedDimSpansCTAs(dstLayout))
    return true;

  // Otherwise distributed shared memory is needed exactly when the CTA
  // layouts are not equal.
  auto srcCTAs = getCTALayout(srcLayout);
  auto dstCTAs = getCTALayout(dstLayout);
  return !(srcCTAs == dstCTAs);
}
/// Returns the entry of getWarpsPerCTA(srcEncoding, srcShape) that belongs to
/// the reduction axis.
unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() {
  auto warpsPerDim = getWarpsPerCTA(srcEncoding, srcShape);
  return warpsPerDim[axis];
}
/// Returns the entry of getThreadsPerWarp(srcEncoding, srcShape) that belongs
/// to the reduction axis.
unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
  auto threadsPerDim = getThreadsPerWarp(srcEncoding, srcShape);
  return threadsPerDim[axis];
}
/// Returns true when exactly one warp is laid out along the reduction axis
/// according to getWarpsPerCTA(srcEncoding, srcShape).
bool ReduceOpHelper::isWarpSynchronous() {
  auto warpsPerDim = getWarpsPerCTA(srcEncoding, srcShape);
  const bool singleWarpOnAxis = warpsPerDim[axis] == 1;
  return singleWarpOnAxis;
}
/// Returns the shape of the scratch buffer used by the reduction.
///
/// For a warp-synchronous reduction the result is {0, 0}. Otherwise it is the
/// source shape with the reduction axis replaced by the inter-warp size.
SmallVector<unsigned> ReduceOpHelper::getScratchRepShape() {
  if (isWarpSynchronous())
    return {0, 0};

  SmallVector<unsigned> repShape = convertType<unsigned>(srcShape);
  repShape[axis] = getInterWarpSizeWithUniqueData();
  return repShape;
}
unsigned ReduceOpHelper::getScratchSizeInBytes() {
auto smemShape = getScratchRepShape();
auto elems = product<unsigned>(smemShape);
unsigned bytesPerElem = 0;
for (const auto &ty : srcElementTypes) {
bytesPerElem += ceil<unsigned>(ty.getIntOrFloatBitWidth(), 8);
}
return bytesPerElem * elems;
}
/// Returns true when the CTA split number of the source encoding along the
/// reduction axis is 1.
bool ReduceOpHelper::isReduceWithinCTA() {
  auto ctaSplit = getCTASplitNum(srcEncoding);
  return ctaSplit[axis] == 1;
}
bool ReduceOpHelper::isAssociative() {
auto dtype = srcElementTypes[0];
if (!type::isFloat(dtype))
return true;
size_t reduce_size = srcShape[axis];
if (reduce_size <= 2)
return true;
bool hasNoAssociativeOp = false;
op.walk([&](Operation *nestedOp) -> WalkResult {
if (isa<arith::AddFOp, arith::MulFOp>(nestedOp)) {
hasNoAssociativeOp = true;
return WalkResult::interrupt();
}
return WalkResult::advance();
});
return !hasNoAssociativeOp;
}
/// Returns the encoding's contiguous-per-thread count along the scan axis.
unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
  auto contig = getEncoding().getContigPerThread();
  return contig[getAxis()];
}
/// Returns the product of the encoding's contiguous-per-thread counts over all
/// dimensions except the scan axis.
unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
  auto perThread = getEncoding().getContigPerThread();
  // Neutralise the scan axis so it does not contribute to the product.
  perThread[getAxis()] = 1;
  return product<unsigned>(perThread);
}
/// Returns the combine region of the wrapped scan op.
Region &ScanLoweringHelper::getCombineOp() {
  Region &combineRegion = scanOp.getCombineOp();
  return combineRegion;
}
/// Returns the encoding's threads-per-warp count along the scan axis.
unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() {
  auto threadsPerDim = getEncoding().getThreadsPerWarp();
  return threadsPerDim[getAxis()];
}
/// Returns the total threads-per-warp product divided by the number of threads
/// per warp along the scan axis.
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() {
  auto totalThreads = product(getEncoding().getThreadsPerWarp());
  auto axisThreads = getAxisNumThreadsPerWarpWithUniqueData();
  return totalThreads / axisThreads;
}
/// Returns (total warps / warps along the scan axis) multiplied by the number
/// of non-axis threads per warp.
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() {
  auto totalWarps = product(getEncoding().getWarpsPerCTA());
  auto nonAxisWarps = totalWarps / getAxisNumWarpsWithUniqueData();
  return nonAxisWarps * getNonAxisNumThreadsPerWarp();
}
/// Returns the encoding's warps-per-CTA count along the scan axis.
unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() {
  auto warpsPerDim = getEncoding().getWarpsPerCTA();
  return warpsPerDim[getAxis()];
}
/// Returns how many blocks are needed along the scan axis: the axis extent
/// divided (rounding up) by contigPerThread * threadsPerWarp * warpsPerCTA for
/// that axis.
unsigned ScanLoweringHelper::getAxisNumBlocks() {
  auto contig = getEncoding().getContigPerThread();
  auto threads = getEncoding().getThreadsPerWarp();
  auto warps = getEncoding().getWarpsPerCTA();
  unsigned scanAxis = getAxis();

  auto elemsPerBlock = contig[scanAxis] * threads[scanAxis] * warps[scanAxis];
  return ceil<unsigned>(getShape()[scanAxis], elemsPerBlock);
}
/// Returns the product, over every dimension except the scan axis, of the
/// number of blocks needed along that dimension (extent divided, rounding up,
/// by contigPerThread * threadsPerWarp * warpsPerCTA).
unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
  auto contig = getEncoding().getContigPerThread();
  auto threads = getEncoding().getThreadsPerWarp();
  auto warps = getEncoding().getWarpsPerCTA();
  auto rank = contig.size();
  unsigned scanAxis = getAxis();

  unsigned blockCount = 1;
  for (unsigned dim = 0; dim < rank; dim++) {
    if (dim != scanAxis) {
      blockCount *= ceil<unsigned>(getShape()[dim],
                                   (contig[dim] * threads[dim] * warps[dim]));
    }
  }
  return blockCount;
}
/// Returns true only when the legacy encoding is a BlockedEncodingAttr.
bool ScanLoweringHelper::isSupported() {
  return isa<BlockedEncodingAttr>(legacyEncoding);
}
/// Returns the scratch size in elements: total warps * non-axis elements per
/// warp * axis blocks * non-axis blocks.
unsigned ScanLoweringHelper::getScratchSizeInElems() {
  unsigned warpCount = product(getEncoding().getWarpsPerCTA());
  unsigned nonAxisElemsPerWarp =
      getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread();
  return warpCount * nonAxisElemsPerWarp * getAxisNumBlocks() *
         getNonAxisNumBlocks();
}
/// Returns the scratch size in bytes for the scan.
///
/// The result is 0 when the scan is unsupported or when only one warp lies
/// along the scan axis. Otherwise it is the scratch element count multiplied
/// by the summed byte widths of all source element types (each bit width is
/// rounded up to whole bytes).
unsigned ScanLoweringHelper::getScratchSizeInBytes() {
  if (!isSupported() || getAxisNumWarpsWithUniqueData() == 1)
    return 0;

  unsigned bytesPerSlot = 0;
  for (const auto &elemTy : srcElementTypes)
    bytesPerSlot += ceil<unsigned>(elemTy.getIntOrFloatBitWidth(), 8);

  return bytesPerSlot * getScratchSizeInElems();
}
// Forward declaration; the definition appears later in this file.
// `mixedTranspositions` and `regBases` are taken by mutable reference, and the
// last argument is the pack count (`nPack`) computed by the caller.
static SmallVector<DecomposedWarpConversion::TranspositionInfo>
getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions,
                          std::vector<std::vector<int32_t>> &regBases,
                          int nPack);
/// Decomposes the register/lane part of the conversion from `srcTy`'s layout
/// to `dstTy`'s layout.
///
/// The result bundles:
///   - `pReg`:  a layout over the "register" dimension only,
///   - `pLane`: a layout over the "lane" dimension only,
///   - the processed list of mixed register<->lane transpositions,
///   - `nPack`, derived from `bitwidth` and the number of register bases.
///
/// \param srcTy    source tensor type.
/// \param dstTy    destination tensor type.
/// \param bitwidth element bit width; used only to derive `nPack`
///                 (must be non-zero, it is used as a divisor).
DecomposedWarpConversion
getWarpLayoutConvertDecomposition(RankedTensorType srcTy,
                                  RankedTensorType dstTy, int bitwidth) {
  // Build both layouts and drop broadcasted register bases from each.
  auto srcLayout = toLinearLayout(srcTy);
  auto dstLayout = toLinearLayout(dstTy);
  auto removeBroadcastSrc = actionRemoveBroadcastedRegs(srcLayout);
  auto removeBroadcastDst = actionRemoveBroadcastedRegs(dstLayout);
  srcLayout = removeBroadcastSrc.apply(srcLayout);
  dstLayout = removeBroadcastDst.apply(dstLayout);

  auto *ctx = srcTy.getContext();
  StringAttr kReg = StringAttr::get(ctx, "register");
  StringAttr kLane = StringAttr::get(ctx, "lane");
  int nSrcRegBases = srcLayout.getInDimSizeLog2(kReg);
  int nDstRegBases = dstLayout.getInDimSizeLog2(kReg);
  int nSrcLaneBases = srcLayout.getInDimSizeLog2(kLane);
  int nDstLaneBases = dstLayout.getInDimSizeLog2(kLane);
  int nRegBases = std::max(nSrcRegBases, nDstRegBases);
  int nLaneBases = std::max(nSrcLaneBases, nDstLaneBases);

  // Restrict both layouts to the (register, lane) input dimensions.
  SmallVector<StringAttr> inDimNames{kReg, kLane};
  auto outDimNames = llvm::to_vector(srcLayout.getOutDimNames());
  auto S = srcLayout.sublayout(inDimNames, outDimNames);
  auto T = dstLayout.sublayout(inDimNames, outDimNames);

  // If the two layouts disagree on the number of register or lane bases, pad
  // each with all-zero bases up to the common maximum.
  if (nSrcRegBases != nDstRegBases || nSrcLaneBases != nDstLaneBases) {
    auto padWithZeros = [&](const LinearLayout &ll) {
      auto newBases = ll.getBases();
      auto padDim = [&](StringAttr dim, int dimSize) {
        auto &dimBases = newBases[dim];
        dimBases.reserve(dimSize);
        for (int i = ll.getInDimSizeLog2(dim); i < dimSize; ++i)
          dimBases.emplace_back(outDimNames.size(), 0);
      };
      padDim(kReg, nRegBases);
      padDim(kLane, nLaneBases);
      return LinearLayout(std::move(newBases), ll.getOutDims(), false);
    };
    S = padWithZeros(S);
    T = padWithZeros(T);
  }

  // `pBases` holds, per (dim, basis index), a vector with one entry per input
  // dimension; the code below treats it as a map on basis vectors.
  auto pBases = S.invertAndCompose(T).getBases();

  // Zero bases need explicit handling. With flattened outputs, a basis is
  // zero iff its single component is zero.
  S = S.flattenOuts();
  T = T.flattenOuts();
  SmallVector<std::pair<int32_t, int32_t>> srcFreeZeros;
  SmallVector<std::pair<int32_t, int32_t>> dstFreeZeros;
  for (auto [dimIdx, dim] : llvm::enumerate(inDimNames)) {
    for (int inIdx = 0; inIdx < S.getInDimSizeLog2(dim); ++inIdx) {
      int sVal = S.getBasis(dim, inIdx)[0];
      int tVal = T.getBasis(dim, inIdx)[0];
      if (sVal == 0 && tVal == 0) {
        // Zero in both layouts: map the basis to itself.
        pBases[dim][inIdx][dimIdx] = 1 << inIdx;
      } else if (sVal == 0) {
        srcFreeZeros.emplace_back(dimIdx, inIdx);
      } else if (tVal == 0) {
        dstFreeZeros.emplace_back(dimIdx, inIdx);
      }
    }
  }
  // Pair up the remaining zero positions in order: each source-only zero is
  // mapped onto a destination-only zero.
  for (auto [srcZeroLoc, dstZeroLoc] : llvm::zip(srcFreeZeros, dstFreeZeros)) {
    auto [srcDimIdx, srcIdx] = srcZeroLoc;
    auto [dstDimIdx, dstIdx] = dstZeroLoc;
    auto inDim = inDimNames[srcDimIdx];
    pBases[inDim][srcIdx][dstDimIdx] = 1 << dstIdx;
  }

  // Output bases for the pure-register and pure-lane parts, zero-initialised.
  LinearLayout::BasesT pRegBases, pLaneBases;
  auto &regBases = pRegBases[kReg];
  auto &laneBases = pLaneBases[kLane];
  regBases.resize(nRegBases, {0});
  laneBases.resize(nLaneBases, {0});

  SmallVector<std::pair<int, int>> mixedTranspositions;

  // Walk every cycle of the map described by `pBases`. Register bases occupy
  // flat indices [0, nRegBases); lane bases follow.
  llvm::BitVector visited(nRegBases + nLaneBases, false);
  auto flatIdx = [&](StringAttr dim, int32_t index) {
    return (dim == kReg) ? index : nRegBases + index;
  };

  for (auto dim : inDimNames) {
    int inDimSize = S.getInDimSizeLog2(dim);
    for (int i = 0; i < inDimSize; ++i) {
      if (visited.test(flatIdx(dim, i)))
        continue;

      // Start a new cycle at (dim, i).
      StringAttr entryDim = dim;
      int32_t entryIdx = i;
      StringAttr currDim = entryDim;
      int32_t currIdx = entryIdx;

      // Bookkeeping for one register->lane ... lane->register excursion.
      int32_t regStartIdx = -1;
      int32_t laneStartIdx = -1;
      int32_t laneEndIdx = -1;
      int32_t regEndIdx = -1;

      do {
        visited.set(flatIdx(currDim, currIdx));

        // Find the successor: the non-zero entry of the image vector gives
        // the target dimension (by position) and index (by log2 of value).
        auto nextVec = pBases.lookup(currDim)[currIdx];
        StringAttr nextDim;
        int32_t nextIdx;
        for (auto [nextDimIdx, nextVal] : llvm::enumerate(nextVec)) {
          if (nextVal != 0) {
            nextDim = inDimNames[nextDimIdx];
            nextIdx = llvm::Log2_32(nextVal);
          }
        }

        if (currDim == kReg && nextDim == kReg) {
          // Register -> register step: record directly.
          regBases[currIdx][0] = 1 << nextIdx;
        } else if (currDim == kLane && nextDim == kLane) {
          // Lane -> lane step: record directly.
          laneBases[currIdx][0] = 1 << nextIdx;
        } else if (currDim == kReg && nextDim == kLane) {
          // Crossing from registers into lanes: remember where we left.
          regStartIdx = currIdx;
          laneStartIdx = nextIdx;
        } else {
          // Crossing from lanes back into registers.
          regEndIdx = nextIdx;
          laneEndIdx = currIdx;
        }

        // Once a lane -> register crossing has been seen, close the
        // excursion: connect its register ends and its lane ends, and record
        // the (register, lane) pair as a mixed transposition.
        if (regEndIdx >= 0) {
          regBases[regStartIdx][0] = 1 << regEndIdx;
          laneBases[laneEndIdx][0] = 1 << laneStartIdx;
          mixedTranspositions.emplace_back(regEndIdx, laneStartIdx);
          regStartIdx = laneStartIdx = laneEndIdx = regEndIdx = -1;
        }

        currDim = nextDim;
        currIdx = nextIdx;
      } while (flatIdx(currDim, currIdx) != flatIdx(entryDim, entryIdx));
    }
  }
  assert(visited.all() && "Cycle walk incomplete");

  // nPack is log2 of (32 / bitwidth) clamped to [1, 4], further limited by
  // the number of register bases not used by a mixed transposition.
  int m = mixedTranspositions.size();
  int nPackPrelim = llvm::Log2_32(std::clamp(32 / bitwidth, 1, 4));
  int nPack = std::min(nPackPrelim, nRegBases - m);

  // Note: this call may reorder `mixedTranspositions` and rewrite `regBases`
  // (and therefore `pRegBases`) in place.
  auto processedTranspos =
      getTranspositionSelectors(mixedTranspositions, regBases, nPack);

  auto pReg = LinearLayout(std::move(pRegBases), {{kReg, 1 << nRegBases}},
                           true);
  auto pLane = LinearLayout(std::move(pLaneBases), {{kLane, 1 << nLaneBases}},
                            true);
  return {std::move(pReg), std::move(pLane), std::move(processedTranspos),
          nPack};
}
/// Builds one TranspositionInfo per mixed (register bit, lane bit)
/// transposition, attaching pre- and post-shuffle 16-bit selectors.
///
/// \param mixedTranspositions (register bit, lane bit) pairs; reordered in
///        place by a stable partition.
/// \param regBases register permutation, one single-entry vector per register
///        bit holding `1 << image`; rewritten in place as low bits are
///        absorbed into the selectors.
/// \param nPack pack count; bits below `nPack` are the "low" bits handled via
///        selectors. When 0, the transpositions are returned unchanged.
/// \returns the list of TranspositionInfo records.
///
/// NOTE(review): 0x3210 / 0x7654 look like identity byte-permute selectors —
/// confirm against the code that consumes `topPreSel` and friends.
static SmallVector<DecomposedWarpConversion::TranspositionInfo>
getTranspositionSelectors(SmallVector<std::pair<int, int>> &mixedTranspositions,
                          std::vector<std::vector<int32_t>> &regBases,
                          int nPack) {
  SmallVector<DecomposedWarpConversion::TranspositionInfo> ret;
  ret.reserve(mixedTranspositions.size());
  if (nPack == 0) {
    // Nothing is packed: wrap every transposition as-is.
    for (auto &t : mixedTranspositions)
      ret.push_back(DecomposedWarpConversion::TranspositionInfo{t});
    return ret;
  }

  // Exchanges, within `sel`, the bits selected by mask 0x4444 with the bits
  // selected by mask (0x1111 << lo), where lo = bitIdx + (2 - nPack).
  auto permuteSelector = [nPack](uint16_t sel, int bitIdx) {
    int lo = bitIdx + (2 - nPack);
    uint16_t maskHi = 0x4444;
    uint16_t maskLo = 0x1111 << lo;
    uint16_t fixed = sel & ~maskHi & ~maskLo;
    int shift = 2 - lo;
    return fixed | ((maskHi & sel) >> shift) | ((maskLo & sel) << shift);
  };

  // Starting from 0x3210 / 0x7654, applies `permuteSelector` once per entry
  // of `lowBits`. Every low bit other than `head` and `tail` is also reset to
  // map to itself in `regBases`.
  auto generateSelectors = [&](int head, int tail, auto &&lowBits) {
    uint16_t topSel = 0x3210;
    uint16_t botSel = 0x7654;
    for (auto lowBit : lowBits) {
      topSel = permuteSelector(topSel, lowBit);
      botSel = permuteSelector(botSel, lowBit);
      if (lowBit != head && lowBit != tail)
        regBases[lowBit][0] = 1 << lowBit;
    }
    return std::pair{topSel, botSel};
  };

  // Register bits that take part in some mixed transposition.
  llvm::SmallSet<int32_t, 6> pairedRegBits;
  for (auto [rBit, lBit] : mixedTranspositions)
    pairedRegBits.insert(rBit);

  // Successor of register bit `b` under the current `regBases` permutation.
  auto next = [&](int b) { return llvm::Log2_32(regBases[b][0]); };

  // Predicate on a transposition: follow the register cycle from its register
  // bit; at the first bit >= nPack, report whether that bit is the starting
  // bit itself or is not paired. False if the cycle has no such bit.
  auto nextHighFree = [&](auto p) {
    int curr = p.first;
    do {
      if (curr >= nPack)
        return curr == p.first || !pairedRegBits.contains(curr);
      curr = next(curr);
    } while (curr != p.first);
    return false;
  };
  // Process the transpositions satisfying the predicate first, keeping
  // relative order within each group.
  std::stable_partition(mixedTranspositions.begin(), mixedTranspositions.end(),
                        nextHighFree);

  // Predecessor of register bit `b` in its cycle.
  auto prev = [&](int b) {
    int tail = b;
    int curr = next(b);
    while (curr != b) {
      tail = curr;
      curr = next(curr);
    }
    return tail;
  };

  // Picks a register bit >= nPack to act as partner for `lowBit` and marks it
  // as paired. For nPack == 2 a special case first tries the successor of the
  // other low bit (when `lowBit` is a fixed point and neither that successor
  // nor the other low bit is already paired); otherwise the lowest unpaired
  // bit starting from nPack is taken.
  auto findPartner = [&](int lowBit, auto &preShufLoBits) {
    if (nPack == 2) {
      int otherLow = 1 - lowBit;
      int b = next(otherLow);
      if (next(lowBit) == lowBit && b >= nPack && !pairedRegBits.contains(b) &&
          !pairedRegBits.contains(otherLow)) {
        preShufLoBits.push_back(otherLow);
        // Splice `otherLow` out of its cycle.
        regBases[prev(otherLow)][0] = 1 << b;
        pairedRegBits.insert(b);
        return b;
      }
    }
    int potentialPartner = nPack;
    while (pairedRegBits.contains(potentialPartner))
      ++potentialPartner;
    pairedRegBits.insert(potentialPartner);
    return potentialPartner;
  };

  for (auto p : mixedTranspositions) {
    int rBit = p.first;
    int lBit = p.second;

    // Collect the register cycle containing rBit, starting at rBit.
    SmallVector<int> cycle;
    int currBit = rBit;
    do {
      cycle.push_back(currBit);
      currBit = next(currBit);
    } while (currBit != rBit);

    // A boundary is any bit >= nPack, or a paired bit other than rBit.
    auto isBoundary = [&](int bit) {
      return bit >= nPack || (pairedRegBits.contains(bit) && bit != rBit);
    };
    auto forwardEnd = llvm::find_if(cycle, isBoundary);
    auto backwardEnd = std::find_if(cycle.rbegin(), cycle.rend(), isBoundary);
    // Low bits reached walking forward from rBit up to the first boundary,
    // and walking backward from the end of the cycle up to the first boundary.
    SmallVector<int> postShufLoBits(cycle.begin(), forwardEnd);
    SmallVector<int> preShufLoBits(cycle.rbegin(), backwardEnd);

    int head;
    int tail;
    int partnerBit = -1;

    if (forwardEnd != cycle.end()) {
      // The cycle contains a boundary bit.
      if (*forwardEnd == rBit || !pairedRegBits.contains(*forwardEnd)) {
        // The boundary is rBit itself or an unpaired bit: use it as partner.
        head = partnerBit = *forwardEnd;
      } else {
        // The boundary is already paired: keep only the first low bit on the
        // post side and move it to the pre side as well.
        head = postShufLoBits.front();
        preShufLoBits.push_back(head);
        postShufLoBits.resize(1);
        pairedRegBits.erase(head);
      }
      tail = *backwardEnd;
      if (tail < nPack && pairedRegBits.contains(tail)) {
        preShufLoBits.insert(preShufLoBits.begin(), tail);
      }
    } else {
      // The whole cycle consists of low bits with no boundary.
      if (next(rBit) != rBit && pairedRegBits.contains(next(rBit))) {
        preShufLoBits.erase(preShufLoBits.begin());
        postShufLoBits.pop_back();
        pairedRegBits.erase(postShufLoBits.front());
        head = rBit;
        tail = next(rBit);
      } else {
        if (postShufLoBits.size() == 2)
          postShufLoBits.pop_back();
        head = tail = preShufLoBits.front();
      }
    }

    if (partnerBit < 0)
      partnerBit = findPartner(head, preShufLoBits);

    // Post selectors consume the post low bits in reverse order.
    auto [topPostSel, botPostSel] =
        generateSelectors(head, tail, llvm::reverse(postShufLoBits));
    auto [topPreSel, botPreSel] = generateSelectors(head, tail, preShufLoBits);
    // Close the remaining register cycle: tail now maps to head.
    regBases[tail][0] = 1 << head;

    DecomposedWarpConversion::TranspositionInfo info;
    info.transposition = {partnerBit, lBit};
    info.topPreSel = topPreSel;
    info.botPreSel = botPreSel;
    info.topPostSel = topPostSel;
    info.botPostSel = botPostSel;

    if (!preShufLoBits.empty()) {
      // Insert before the first existing record whose topPostSel equals the
      // selector chosen from the last pre-shuffle low bit (or at the end if
      // there is none).
      uint16_t sel = (nPack - preShufLoBits.back()) == 2 ? 0x6240 : 0x5410;
      auto it =
          llvm::find_if(ret, [&](auto &t) { return t.topPostSel == sel; });
      ret.insert(it, info);
    } else {
      ret.push_back(info);
    }
  }

  // If, for nPack == 2, register bits 0 and 1 are left swapped with each
  // other, override the post selectors of the last record.
  if (nPack == 2 && regBases[0][0] == 2 && regBases[1][0] == 1 && ret.size()) {
    auto &t = ret.back();
    t.topPostSel = 0x3120;
    t.botPostSel = 0x7564;
  }
  return ret;
}
/// Splits a reshape from `srcShape` to `dstShape` into groups of consecutive
/// dimensions.
///
/// Each returned pair holds the source dimension indices and the destination
/// dimension indices of one group. Dimensions are consumed left to right from
/// whichever side currently has the smaller running element count; a new
/// group starts once both counts are equal and neither side has a pending
/// size-1 dimension to absorb.
///
/// \param srcShape shape before the reshape.
/// \param dstShape shape after the reshape; expected to describe the same
///        number of elements (violations trip the asserts below).
/// \returns the groups in order; empty when `srcShape` is empty.
SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>>
getReshapeDecomposition(ArrayRef<int64_t> srcShape,
                        ArrayRef<int64_t> dstShape) {
  SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>> ret;

  if (srcShape.empty()) {
    assert(dstShape.empty());
    return ret;
  }
  ret.push_back({});

  // Ranks and indices are signed 64-bit so comparisons are not mixed-sign,
  // and the running element counts match the int64_t extents instead of
  // being narrowed to int.
  const int64_t srcRank = static_cast<int64_t>(srcShape.size());
  const int64_t dstRank = static_cast<int64_t>(dstShape.size());
  int64_t srcIdx = 0;
  int64_t dstIdx = 0;
  int64_t srcNElems = 1;
  int64_t dstNElems = 1;
  while (srcIdx < srcRank || dstIdx < dstRank) {
    if (srcNElems < dstNElems ||
        // Always take at least one source dimension per group.
        (srcIdx < srcRank && srcNElems == 1) ||
        // Absorb size-1 source dimensions into the current group.
        (srcIdx < srcRank && srcShape[srcIdx] == 1)) {
      assert(srcIdx < srcRank);
      srcNElems *= srcShape[srcIdx];
      ret.back().first.push_back(srcIdx);
      srcIdx++;
    } else if (dstNElems < srcNElems ||
               // Absorb size-1 destination dimensions likewise.
               (dstIdx < dstRank && dstShape[dstIdx] == 1)) {
      assert(dstIdx < dstRank);
      dstNElems *= dstShape[dstIdx];
      ret.back().second.push_back(dstIdx);
      dstIdx++;
    } else {
      // Both sides are balanced: open the next group.
      ret.push_back({});
      srcNElems = 1;
      dstNElems = 1;
    }
  }
  return ret;
}
/// Returns the product of the contiguous-per-thread counts of all dimensions
/// that precede the scan axis in `getOrder()`.
unsigned ScanLoweringHelper::getAxisElementStride() {
  const auto scanAxis = getAxis();
  auto contig = getEncoding().getContigPerThread();
  auto dimOrder = getOrder();

  unsigned elemStride = 1;
  for (unsigned dim : dimOrder) {
    if (dim != scanAxis) {
      elemStride *= contig[dim];
      continue;
    }
    return elemStride;
  }
  llvm_unreachable("Axis not found in order");
}
unsigned ScanLoweringHelper::getAxisThreadStride() {
auto encoding = getEncoding();
auto kThread = StringAttr::get(encoding.getContext(), "lane");
auto threadsPerWarp = encoding.basesPerDim(kThread, false);
auto order = getOrder();
unsigned stride = 1;
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= threadsPerWarp[dim];
}
llvm_unreachable("Axis not found in order");
}
/// Returns the product of the block counts of all dimensions that precede the
/// scan axis in `getOrder()`. The block count of a dimension is its extent
/// divided (rounding up) by contigPerThread * threadsPerWarp * warpsPerCTA.
unsigned ScanLoweringHelper::getAxisBlockStride() {
  auto dimOrder = getOrder();
  unsigned blockStride = 1;
  auto contig = getEncoding().getContigPerThread();
  auto threads = getEncoding().getThreadsPerWarp();
  auto warps = getEncoding().getWarpsPerCTA();

  for (unsigned dim : dimOrder) {
    if (dim != getAxis()) {
      blockStride *= ceil<unsigned int>(
          getShape()[dim], contig[dim] * threads[dim] * warps[dim]);
      continue;
    }
    return blockStride;
  }
  llvm_unreachable("Axis not found in order");
}
/// Constructs a helper that wraps `gatherOp`; the op handle is stored as-is.
GatherLoweringHelper::GatherLoweringHelper(triton::GatherOp gatherOp)
: gatherOp(gatherOp) {}
unsigned GatherLoweringHelper::getScratchSizeInBytes() {
if (isWarpLocal())
return 0;
RankedTensorType srcType = gatherOp.getSrc().getType();
return product(srcType.getShape()) *
ceil<unsigned>(srcType.getElementTypeBitWidth(), 8);
}
bool GatherLoweringHelper::isWarpLocal() {
RankedTensorType srcType = gatherOp.getSrc().getType();
RankedTensorType idxType = gatherOp.getIndices().getType();
LinearLayout srcLayout = toLinearLayout(srcType);
LinearLayout idxLayout = toLinearLayout(idxType);
Builder b(gatherOp.getContext());
StringAttr kBlock = b.getStringAttr("block");
StringAttr kWarp = b.getStringAttr("warp");
StringAttr kLane = b.getStringAttr("lane");
StringAttr kGatherDim =
b.getStringAttr("dim" + std::to_string(gatherOp.getAxis()));
if (!srcLayout.sublayoutIsZero({kBlock, kWarp}, kGatherDim) ||
!idxLayout.sublayoutIsZero({kBlock, kWarp}, kGatherDim))
return false;
SmallVector<StringAttr> otherDims;
for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) {
if (dim != gatherOp.getAxis()) {
otherDims.push_back(b.getStringAttr("dim" + Twine(dim)));
}
}
if (srcLayout.sublayout({kBlock, kWarp}, otherDims) !=
idxLayout.sublayout({kBlock, kWarp}, otherDims))
return false;
return srcLayout.sublayout(kLane, otherDims) ==
idxLayout.sublayout(kLane, otherDims);
}
/// Returns the number of elements described by `shape`: 0 for an empty shape,
/// otherwise the product of its extents.
unsigned getNumScratchElements(ArrayRef<unsigned> shape) {
  return shape.empty() ? 0 : product<unsigned>(shape);
}
/// Reports whether the dot op `op` can use MMA of the given `version`.
///
/// Version 5 is decided entirely by its own checks (environment switch, no
/// integer element types, rank-2 result, 4 or 8 warps, minimum K, and result
/// tile divisibility by 64 x 16). Version 3 first applies its own rejection
/// checks and then falls through to the generic rules shared with the other
/// versions: f32 x f32 requires TF32 input precision and version >= 2;
/// anything else defers to the per-operand `supportMMA(Value, int)` overload.
bool supportMMA(triton::DotOp op, int version) {
  auto aElemTy = op.getA().getType().getElementType();
  auto bElemTy = op.getB().getType().getElementType();

  if (version == 5) {
    if (triton::tools::getBoolEnv("DISABLE_MMA_V5"))
      return false;
    RankedTensorType aTy = op.getA().getType();
    int k = aTy.getShape().back();
    auto resultTy = op.getType();
    auto shapePerCTA = getShapePerCTA(resultTy);
    auto rank = shapePerCTA.size();
    int numWarps = lookupNumWarps(op);
    // Integer operands or results are rejected.
    if (aElemTy.isInteger() || bElemTy.isInteger() ||
        resultTy.getElementType().isInteger())
      return false;
    // Only rank-2 results with 4 or 8 warps are accepted.
    if (op.getType().getRank() != 2)
      return false;
    if (numWarps != 4 && numWarps != 8)
      return false;
    // K must span at least 256 bits of the A element type.
    if (k < 256 / aElemTy.getIntOrFloatBitWidth())
      return false;
    return shapePerCTA[rank - 2] % 64 == 0 && shapePerCTA[rank - 1] % 16 == 0;
  }

  if (version == 3) {
    if (triton::tools::getBoolEnv("DISABLE_MMA_V3"))
      return false;
    auto resultTy = op.getType();
    RankedTensorType aTy = op.getA().getType();
    int k = aTy.getShape().back();
    // K must span at least 256 bits of the A element type.
    if (k < 256 / aElemTy.getIntOrFloatBitWidth())
      return false;
    auto shapePerCTA = getShapePerCTA(resultTy);
    auto rank = shapePerCTA.size();
    int numWarps = lookupNumWarps(op);
    if (rank == 3)
      return false;
    // Warp count, result tile shape and A element type must all qualify.
    if (numWarps % 4 != 0 || shapePerCTA[rank - 2] % 64 != 0 ||
        shapePerCTA[rank - 1] % 16 != 0 ||
        !(llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
          aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
          aElemTy.isF32()))
      return false;
    // These two fp8 types accumulating into f32 with maxNumImpreciseAcc < 32
    // are rejected.
    if (op.getMaxNumImpreciseAcc() < 32 &&
        (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
        cast<RankedTensorType>(op.getType()).getElementType().isF32())
      return false;
  }

  if (aElemTy.isF32() && bElemTy.isF32())
    return op.getInputPrecision() == InputPrecision::TF32 && version >= 2;

  return supportMMA(op.getA(), version) && supportMMA(op.getB(), version);
}
/// Reports whether the element type of `value` is usable as an MMA operand
/// for the given `version` (asserted to be 1, 2 or 3).
///
/// The four listed fp8 types, f16 and bf16 are accepted for every version;
/// f32, f64 and 8-bit integers additionally require version >= 2.
bool supportMMA(Value value, int version) {
  assert((version == 1 || version == 2 || version == 3) &&
         "Unexpected MMA layout version found");
  auto elemTy =
      cast<triton::gpu::TensorOrMemDesc>(value.getType()).getElementType();

  if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
                Float8E4M3FNUZType>(elemTy))
    return true;
  if (elemTy.isF16() || elemTy.isBF16())
    return true;
  if ((elemTy.isF32() || elemTy.isF64()) && version >= 2)
    return true;
  return elemTy.isInteger(8) && version >= 2;
}
/// Builds the layout `dst.invertAndCompose(src)` and strips from it as many
/// shared trailing input dimensions as possible.
///
/// Input dimension names are compared from the last one backwards; the run of
/// names common to both layouts is then quotiented out one at a time, in that
/// order, stopping at the first dimension for which `quotient` yields no
/// value.
LinearLayout minimalCvtLayout(Type srcTy_, Type dstTy_) {
  auto srcTy = cast<triton::gpu::TensorOrMemDesc>(srcTy_);
  auto dstTy = cast<triton::gpu::TensorOrMemDesc>(dstTy_);
  LinearLayout srcLayout = toLinearLayout(srcTy);
  LinearLayout dstLayout = toLinearLayout(dstTy);

  auto srcInDims = to_vector(srcLayout.getInDimNames());
  auto dstInDims = to_vector(dstLayout.getInDimNames());

  // Gather the matching trailing in-dims, last dimension first.
  SmallVector<StringAttr> sharedDims;
  auto srcIt = srcInDims.rbegin();
  auto dstIt = dstInDims.rbegin();
  while (srcIt != srcInDims.rend() && dstIt != dstInDims.rend() &&
         *srcIt == *dstIt) {
    sharedDims.push_back(*srcIt);
    ++srcIt;
    ++dstIt;
  }

  auto conversion = dstLayout.invertAndCompose(srcLayout);
  for (auto dim : sharedDims) {
    auto reduced = conversion.quotient(dim);
    if (!reduced.has_value())
      break;
    conversion = *reduced;
  }
  return conversion;
}
/// Returns true when the minimal conversion layout between `srcTy` and
/// `dstTy` has no output dimensions at all, or exactly one output dimension
/// named "register".
bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) {
  auto cvtLayout = minimalCvtLayout(srcTy, dstTy);
  MLIRContext *ctx = srcTy.getContext();
  auto kRegister = StringAttr::get(ctx, "register");

  auto outDims = to_vector(cvtLayout.getOutDimNames());
  if (outDims.empty())
    return true;
  return outDims.size() == 1 && outDims.front() == kRegister;
}
/// Returns true when the minimal conversion layout has exactly the output
/// dimensions {"register", "lane"} and the warp-level decomposition (computed
/// with bitwidth 32) contains fewer than two mixed transpositions.
bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
  auto cvtLayout = minimalCvtLayout(srcTy, dstTy);
  MLIRContext *ctx = srcTy.getContext();
  auto kRegister = StringAttr::get(ctx, "register");
  auto kLane = StringAttr::get(ctx, "lane");

  const bool touchesOnlyRegAndLane =
      to_vector(cvtLayout.getOutDimNames()) ==
      SmallVector<StringAttr, 2>{kRegister, kLane};
  if (!touchesOnlyRegAndLane)
    return false;

  auto decomposition = getWarpLayoutConvertDecomposition(srcTy, dstTy, 32);
  return (decomposition.mixedTranspositions.size() < 2);
}
/// Returns true when the conversion is neither a pure register reordering nor
/// one that qualifies for warp shuffles.
bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
  if (cvtReordersRegisters(srcTy, dstTy))
    return false;
  return !cvtNeedsWarpShuffle(srcTy, dstTy);
}
namespace {
/// FIFO work list of operations that rejects duplicates: an operation can be
/// queued at most once at any given time.
struct DFSSubgraphState {
  DFSSubgraphState() : set(), deque() {}

  /// Operations currently queued (membership test).
  DenseSet<Operation *> set;
  /// Operations currently queued (insertion order).
  std::deque<Operation *> deque;

  /// Appends `op` unless it is already queued. Returns whether it was added.
  bool push_back(Operation *op) {
    const bool isNew = set.insert(op).second;
    if (isNew)
      deque.push_back(op);
    return isNew;
  }

  /// Removes and returns the oldest queued operation. The queue must not be
  /// empty.
  Operation *pop_front() {
    Operation *oldest = deque.front();
    deque.pop_front();
    set.erase(oldest);
    return oldest;
  }

  /// Returns true when nothing is queued.
  bool empty() { return deque.empty(); }
};
struct DFSState {
DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
const SetVector<Operation *> &toSort;
SmallVector<Operation *, 16> topologicalCounts;
DenseSet<Operation *> seen;
void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph,
SmallVector<Operation *, 4> &readyQueue) {
bool ready = true;
for (Value operand : op->getOperands()) {
auto def = operand.getDefiningOp();
if (def && !seen.count(def)) {
subGraph.push_back(def);
ready = false;
}
}
Operation *parent = op->getParentOp();
while (parent) {
if (!seen.count(parent)) {
subGraph.push_back(parent);
ready = false;
}
parent = parent->getParentOp();
}
if (ready)
readyQueue.push_back(op);
}
};
/// Traverses the graph reachable from `root` and appends to
/// `state->topologicalCounts` those visited operations that belong to
/// `state->toSort`.
///
/// An operation is processed only once all defining ops of its operands and
/// all of its ancestors have been processed (see `addToReadyQueue`);
/// unprocessed dependencies are queued and handled first. Operations are
/// appended in the reverse of the order in which they were processed here.
void dfsPostorder(Operation *root, DFSState *state) {
  DFSSubgraphState subGraph;
  subGraph.push_back(root);
  SmallVector<Operation *> ops;
  while (!subGraph.empty()) {
    // Nodes in the ready queue have all their dependencies processed.
    SmallVector<Operation *, 4> readyQueue;
    auto *current = subGraph.pop_front();
    state->addToReadyQueue(current, subGraph, readyQueue);
    while (!readyQueue.empty()) {
      Operation *ready = readyQueue.pop_back_val();
      // Skip operations that were already processed.
      if (!state->seen.insert(ready).second)
        continue;
      ops.push_back(ready);
      // Users of the results may now be ready.
      for (Value result : ready->getResults()) {
        for (Operation *op : result.getUsers())
          state->addToReadyQueue(op, subGraph, readyQueue);
      }
      // So may the operations nested directly in its regions.
      for (Region &region : ready->getRegions()) {
        for (Operation &op : region.getOps())
          state->addToReadyQueue(&op, subGraph, readyQueue);
      }
    }
  }

  // Emit in reverse processing order, keeping only requested operations.
  for (Operation *op : llvm::reverse(ops)) {
    if (state->toSort.count(op) > 0)
      state->topologicalCounts.push_back(op);
  }
}
}
/// Returns the operations of `toSort` in the order produced by running
/// `dfsPostorder` from each of them in turn and reversing the accumulated
/// list. An empty input is returned unchanged.
SetVector<Operation *>
multiRootTopologicalSort(const SetVector<Operation *> &toSort) {
  if (toSort.empty())
    return toSort;

  // Run one traversal per root; already-seen operations are skipped by the
  // shared state.
  DFSState state(toSort);
  for (Operation *root : toSort) {
    assert(toSort.count(root) == 1 && "NYI: multi-sets not supported");
    dfsPostorder(root, &state);
  }

  // The traversals emit in reverse; flip the list back.
  SetVector<Operation *> sorted;
  for (Operation *op : llvm::reverse(state.topologicalCounts))
    sorted.insert(op);
  return sorted;
}
/// Computes the closure of `op` under repeated backward and forward slicing
/// and returns it in multi-root topological order.
///
/// Starting from `{op}`, every operation in the growing set contributes its
/// backward slice (block arguments omitted, filtered by `backwardFilter`) and
/// its forward slice (filtered by `forwardFilter`) until no new operations
/// appear.
SetVector<Operation *> multiRootGetSlice(Operation *op,
                                         TransitiveFilter backwardFilter,
                                         TransitiveFilter forwardFilter) {
  SetVector<Operation *> closure;
  closure.insert(op);

  SetVector<Operation *> backward;
  SetVector<Operation *> forward;
  // `closure` grows while it is being scanned, so iterate by index.
  for (unsigned idx = 0; idx != closure.size(); ++idx) {
    Operation *current = closure[idx];

    // Everything `current` transitively depends on.
    backward.clear();
    BackwardSliceOptions backwardOpts;
    backwardOpts.omitBlockArguments = true;
    backwardOpts.filter = backwardFilter;
    (void)getBackwardSlice(current, &backward, backwardOpts);
    closure.insert(backward.begin(), backward.end());

    // Everything that transitively depends on `current`.
    forward.clear();
    getForwardSlice(current, &forward, forwardFilter);
    closure.insert(forward.begin(), forward.end());
  }
  return multiRootTopologicalSort(closure);
}
namespace {
/// Data-flow analysis that records a constant lattice value for the result of
/// every operation matched by `m_Constant`, and an unknown-constant value for
/// all other results and for the block arguments of nested regions.
class ConstantAnalysis : public DataFlowAnalysis {
public:
  using DataFlowAnalysis::DataFlowAnalysis;

  /// Visits every operation nested under `top` (including `top`) once.
  /// Fails if any visit fails.
  LogicalResult initialize(Operation *top) override {
    WalkResult result = top->walk([&](Operation *op) {
      ProgramPoint programPoint(op);
      if (failed(visit(&programPoint)))
        return WalkResult::interrupt();
      return WalkResult::advance();
    });
    return success(!result.wasInterrupted());
  }

  /// Updates the lattice for the operation at `point`. Always succeeds.
  LogicalResult visit(ProgramPoint *point) override {
    Operation *op = point->getOperation();
    Attribute value;
    if (matchPattern(op, m_Constant(&value))) {
      // Constant op: join the matched attribute into result 0's lattice.
      auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(
          op->getResult(0));
      propagateIfChanged(constant, constant->join(dataflow::ConstantValue(
                                       value, op->getDialect())));
      return success();
    }
    // Anything else: results and the block arguments of every region are
    // marked as unknown constants.
    setAllToUnknownConstants(op->getResults());
    for (Region &region : op->getRegions()) {
      for (Block &block : region.getBlocks())
        setAllToUnknownConstants(block.getArguments());
    }
    return success();
  }

private:
  /// Joins a ConstantValue built from (nullptr, nullptr) into the lattice of
  /// each value in `values`.
  void setAllToUnknownConstants(ValueRange values) {
    dataflow::ConstantValue unknownConstant(nullptr, nullptr);
    for (Value value : values) {
      auto *constant =
          getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(value);
      propagateIfChanged(constant, constant->join(unknownConstant));
    }
  }
};
}
std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
auto solver = std::make_unique<DataFlowSolver>();
solver->load<dataflow::DeadCodeAnalysis>();
solver->load<ConstantAnalysis>();
return solver;
}
/// Returns true when the conversion `dst.invertAndCompose(src)` is trivial
/// over the "warp" dimension and neither layout has a free-variable mask bit
/// set for "warp".
bool isCvtWarpSync(const triton::LinearLayout &srcLayout,
                   const triton::LinearLayout &dstLayout) {
  // The context is taken from the first input dimension name of the source
  // layout, so the source layout must have at least one input dimension.
  auto *context = srcLayout.getInDimNames().begin()->getContext();
  auto conversion = dstLayout.invertAndCompose(srcLayout);
  auto warpDim = StringAttr::get(context, "warp");

  if (!conversion.isTrivialOver(warpDim))
    return false;
  if (srcLayout.getFreeVariableMasks()[warpDim] != 0)
    return false;
  return dstLayout.getFreeVariableMasks()[warpDim] == 0;
}
}