#include "triton/Analysis/Utility.h"
#include <deque>
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Support/LLVM.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/LinearLayout.h"
#include "triton/Tools/Sys/GetEnv.hpp"
namespace mlir {
namespace {

using namespace triton;
using namespace triton::gpu;

/// Translates `axis` of a (possibly nested) slice layout into the matching
/// axis of the innermost non-slice parent layout. Each slice level removed
/// dimension `getDim()`, so axes at or past that dimension shift up by one.
int getParentAxis(Attribute layout, int axis) {
  while (auto slice = dyn_cast<SliceEncodingAttr>(layout)) {
    if (axis >= slice.getDim())
      ++axis;
    layout = slice.getParent();
  }
  return axis;
}

/// Returns the thread order of the innermost non-slice parent of `layout`.
SmallVector<unsigned> getParentOrder(Attribute layout) {
  Attribute base = layout;
  while (auto slice = mlir::dyn_cast<SliceEncodingAttr>(base))
    base = slice.getParent();
  return getThreadOrder(base);
}

} // namespace
bool ReduceOpHelper::isReductionOnLayoutFastAxis() {
return getParentAxis(getSrcLayout(), axis) ==
getParentOrder(getSrcLayout())[0];
}
SmallVector<unsigned> ReduceOpHelper::getOrderWithAxisAtBeginning() {
  // Returns the source layout's order with the reduction axis moved to the
  // front; the relative order of the remaining dimensions is preserved.
  auto order = getOrder(getSrcLayout());
  auto it = std::find(order.begin(), order.end(), axis);
  // The reduction axis must be one of the layout dims; moving `end()` would
  // be undefined behavior.
  assert(it != order.end() && "reduction axis not found in layout order");
  // Rotate [begin, it] right by one: equivalent to erase(it) + insert(begin).
  std::rotate(order.begin(), it, it + 1);
  return order;
}
unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
auto srcLayout = getSrcLayout();
if (isReductionOnLayoutFastAxis()) {
return 1;
}
unsigned threadOffset = 1;
if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
auto parentLayout = sliceLayout.getParent();
auto threadsPerWarp = getThreadsPerWarp(parentLayout);
threadOffset = threadsPerWarp[sliceLayout.getDim()];
} else {
auto threadsPerWarp = getThreadsPerWarp(srcLayout);
auto order = getThreadOrder(srcLayout);
for (unsigned i = 0; i < order.size(); i++) {
if (order[i] == axis)
break;
threadOffset *= threadsPerWarp[order[i]];
}
}
return threadOffset;
}
bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout) {
  // Returns false for single-CTA programs. Otherwise returns true when dst is
  // a slice whose sliced dim is split across CTAs, or when the CTA layouts of
  // src and dst differ. A src slice split across CTAs is a fatal error.
  unsigned numCTAs = getNumCTAs(srcLayout);
  assert(numCTAs == getNumCTAs(dstLayout) &&
         "Invalid layout conversion: the numbers of CTAs of src and dst "
         "layouts are different");
  if (numCTAs == 1)
    return false;

  // True iff `layout` is a slice whose sliced dim has more than one CTA.
  auto slicedDimSplitAcrossCTAs = [](Attribute layout) -> bool {
    auto slice = mlir::dyn_cast<SliceEncodingAttr>(layout);
    if (!slice)
      return false;
    return getCTAsPerCGA(slice.getParent())[slice.getDim()] != 1;
  };

  if (slicedDimSplitAcrossCTAs(srcLayout))
    llvm::report_fatal_error("Layout conversion to be implemented");
  if (slicedDimSplitAcrossCTAs(dstLayout))
    return true;

  return getCTALayout(srcLayout) != getCTALayout(dstLayout);
}
unsigned ReduceOpHelper::getInterWarpSize() {
  // Warps along the reduction axis, capped by how many intra-warp groups the
  // axis length contains.
  unsigned axisLen = static_cast<unsigned>(srcShape[axis]);
  unsigned groups = axisLen / getIntraWarpSize();
  unsigned warpsOnAxis = getWarpsPerCTA(getSrcLayout())[axis];
  return std::min(groups, warpsOnAxis);
}
unsigned ReduceOpHelper::getIntraWarpSize() {
  // Lanes along the reduction axis, capped by the axis length.
  unsigned axisLen = static_cast<unsigned>(srcShape[axis]);
  unsigned lanesOnAxis = getThreadsPerWarp(getSrcLayout())[axis];
  return std::min(axisLen, lanesOnAxis);
}
unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() {
  // Like getInterWarpSize, but counting only warps/lanes with unique data.
  unsigned axisLen = static_cast<unsigned>(srcShape[axis]);
  unsigned groups = axisLen / getIntraWarpSizeWithUniqueData();
  unsigned uniqueWarps =
      getWarpsPerCTAWithUniqueData(getSrcLayout(), getSrcShape())[axis];
  return std::min(groups, uniqueWarps);
}
unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
  // Lanes with unique data along the axis, capped by how many per-thread
  // contiguous chunks the axis contains.
  auto layout = getSrcLayout();
  auto shape = getSrcShape();
  unsigned axisLen = static_cast<unsigned>(srcShape[axis]);
  unsigned perThread = getUniqueContigPerThread(layout, shape)[axis];
  unsigned uniqueLanes = getThreadsPerWarpWithUniqueData(layout, shape)[axis];
  return std::min(axisLen / perThread, uniqueLanes);
}
unsigned ReduceOpHelper::getThreadsReductionAxis() {
auto srcLayout = getSrcLayout();
auto srcShape = getSrcShape();
return getThreadsPerWarpWithUniqueData(srcLayout, srcShape)[axis] *
getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis];
}
bool ReduceOpHelper::isWarpSynchronous() {
auto srcLayout = getSrcLayout();
auto srcShape = getSrcShape();
return getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis] == 1;
}
SmallVector<unsigned> ReduceOpHelper::getScratchRepShape() {
  // Warp-synchronous reductions get a {0, 0} (zero-element) scratch shape.
  if (isWarpSynchronous())
    return {0, 0};
  // Source shape with the reduction axis shrunk to the number of
  // participating warps (unique data).
  SmallVector<unsigned> shape = convertType<unsigned>(getSrcShape());
  shape[axis] = getInterWarpSizeWithUniqueData();
  return shape;
}
unsigned ReduceOpHelper::getScratchSizeInBytes() {
auto smemShape = getScratchRepShape();
auto elems = product<unsigned>(smemShape);
unsigned bytesPerElem = 0;
for (const auto &ty : srcElementTypes) {
bytesPerElem += ceil<unsigned>(ty.getIntOrFloatBitWidth(), 8);
}
return bytesPerElem * elems;
}
bool ReduceOpHelper::isReduceWithinCTA() {
auto axis = getAxis();
auto srcLayout = getSrcLayout();
auto CTASplitNum = getCTASplitNum(srcLayout);
assert(axis < CTASplitNum.size());
return CTASplitNum[axis] == 1;
}
bool ReduceOpHelper::isSupportedLayout() {
  // Reductions that cross CTAs are not supported.
  if (!isReduceWithinCTA())
    return false;
  auto srcLayout = getSrcLayout();
  // Blocked and slice layouts are always supported. (The original bound an
  // unused `sliceLayout` variable for this check.)
  if (isa<BlockedEncodingAttr, SliceEncodingAttr>(srcLayout))
    return true;
  // MMA-style layouts report their own support.
  if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(srcLayout))
    return mmaLayout.supportReduction();
  return false;
}
unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
  // sizePerThread of the blocked encoding along the scan axis.
  auto sizePerThread = getEncoding().getSizePerThread();
  return sizePerThread[getAxis()];
}
unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
SmallVector<unsigned> sizePerThreads = getContigPerThread(getEncoding());
sizePerThreads[getAxis()] = 1;
return product<unsigned>(sizePerThreads);
}
Region &ScanLoweringHelper::getCombineOp() {
  // Forwards the scan op's combine region.
  Region &combine = scanOp.getCombineOp();
  return combine;
}
unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() {
  // Threads-per-warp of the encoding along the scan axis.
  auto lanes = getThreadsPerWarp(getEncoding());
  return lanes[getAxis()];
}
unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() {
  // Threads-per-warp holding unique data (for this shape) on the scan axis.
  auto lanes = getThreadsPerWarpWithUniqueData(getEncoding(), getShape());
  return lanes[getAxis()];
}
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() {
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
threadsPerWarp[getAxis()] = 1;
return product<unsigned>(threadsPerWarp);
}
unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() {
unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp();
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
warpsPerCTA[getAxis()] = 1;
unsigned numParallelWarpsPerCTA = product<unsigned>(warpsPerCTA);
return numParallelThreadsPerWarp * numParallelWarpsPerCTA;
}
unsigned ScanLoweringHelper::getAxisNumWarps() {
  // Warps-per-CTA of the encoding along the scan axis.
  auto warps = getWarpsPerCTA(getEncoding());
  return warps[getAxis()];
}
unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() {
  // Warps holding unique data (for this shape) along the scan axis.
  auto warps = getWarpsPerCTAWithUniqueData(getEncoding(), getShape());
  return warps[getAxis()];
}
unsigned ScanLoweringHelper::getAxisNumBlocks() {
auto sizePerThreads = getSizePerThread(getEncoding());
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
unsigned axis = getAxis();
return ceil<unsigned>(
getShape()[axis],
(sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
}
unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
auto sizePerThreads = getSizePerThread(getEncoding());
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
unsigned axis = getAxis();
unsigned numBlocks = 1;
for (unsigned i = 0; i < sizePerThreads.size(); i++) {
if (i == axis)
continue;
numBlocks *=
ceil<unsigned>(getShape()[i], (sizePerThreads[i] * threadsPerWarp[i] *
warpsPerCTA[i]));
}
return numBlocks;
}
bool ScanLoweringHelper::isSupported() {
if (!isa<BlockedEncodingAttr>(getEncoding()))
return false;
return true;
}
unsigned ScanLoweringHelper::getScratchSizeInElems() {
auto mod = scanOp->getParentOfType<ModuleOp>();
unsigned numWarps = TritonGPUDialect::getNumWarps(mod);
unsigned numNonAxisElementsPerWarp =
getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread();
unsigned numElements = numWarps * numNonAxisElementsPerWarp *
getAxisNumBlocks() * getNonAxisNumBlocks();
return numElements;
}
unsigned ScanLoweringHelper::getScratchSizeInBytes() {
  // No scratch when a single unique-data warp covers the scan axis.
  if (getAxisNumWarpsWithUniqueData() == 1)
    return 0;
  unsigned bytesPerElem = 0;
  for (const auto &elemTy : srcElementTypes)
    bytesPerElem += ceil<unsigned>(elemTy.getIntOrFloatBitWidth(), 8);
  return bytesPerElem * getScratchSizeInElems();
}
SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>>
getReshapeDecomposition(ArrayRef<int64_t> srcShape,
                        ArrayRef<int64_t> dstShape) {
  // Partitions src and dst dims into consecutive groups with equal element
  // counts. Each pair holds the src dim indices and dst dim indices of one
  // group; size-1 dims are absorbed into the current group.
  SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>> ret;
  if (srcShape.empty()) {
    assert(dstShape.empty());
    return ret;
  }
  ret.push_back({});
  const size_t srcRank = srcShape.size();
  const size_t dstRank = dstShape.size();
  size_t srcIdx = 0;
  size_t dstIdx = 0;
  // Element counts of the current group. int64_t matches the shape element
  // type, so products of large dims cannot overflow `int`.
  int64_t srcNElems = 1;
  int64_t dstNElems = 1;
  while (srcIdx < srcRank || dstIdx < dstRank) {
    // Can't take the shortcut of "if (srcNElems == dstNElems)" because we
    // also need to handle dimensions of size 1.
    if (srcNElems < dstNElems || (srcIdx < srcRank && srcNElems == 1) ||
        (srcIdx < srcRank && srcShape[srcIdx] == 1)) {
      assert(srcIdx < srcRank);
      srcNElems *= srcShape[srcIdx];
      ret.back().first.push_back(static_cast<int64_t>(srcIdx));
      srcIdx++;
    } else if (dstNElems < srcNElems ||
               (dstIdx < dstRank && dstShape[dstIdx] == 1)) {
      assert(dstIdx < dstRank);
      dstNElems *= dstShape[dstIdx];
      ret.back().second.push_back(static_cast<int64_t>(dstIdx));
      dstIdx++;
    } else {
      // Counts are equal: close this group and start a new one.
      ret.push_back({});
      srcNElems = 1;
      dstNElems = 1;
    }
  }
  return ret;
}
BlockedEncodingAttr ScanLoweringHelper::getEncoding() {
  // Unchecked cast: srcEncoding must already be a BlockedEncodingAttr.
  auto blocked = cast<BlockedEncodingAttr>(srcEncoding);
  return blocked;
}
unsigned ScanLoweringHelper::getAxisElementStride() {
  // Product of contiguous-per-thread sizes of the dims that precede the scan
  // axis in the layout order.
  auto enc = getEncoding();
  auto contig = getContigPerThread(enc);
  unsigned scanAxis = getAxis();
  unsigned stride = 1;
  for (unsigned dim : getOrder(enc)) {
    if (dim == scanAxis)
      return stride;
    stride *= contig[dim];
  }
  llvm_unreachable("Axis not found in order");
}
unsigned ScanLoweringHelper::getAxisThreadStride() {
  // Product of threads-per-warp of the dims that precede the scan axis in
  // the layout order.
  auto enc = getEncoding();
  auto lanes = enc.getThreadsPerWarp();
  unsigned scanAxis = getAxis();
  unsigned stride = 1;
  for (unsigned dim : getOrder(enc)) {
    if (dim == scanAxis)
      return stride;
    stride *= lanes[dim];
  }
  llvm_unreachable("Axis not found in order");
}
unsigned ScanLoweringHelper::getAxisBlockStride() {
auto order = getOrder(getEncoding());
unsigned stride = 1;
auto sizePerThreads = getSizePerThread(getEncoding());
auto threadsPerWarp = getThreadsPerWarp(getEncoding());
auto warpsPerCTA = getWarpsPerCTA(getEncoding());
for (unsigned dim : order) {
if (dim == getAxis())
return stride;
stride *= ceil<unsigned int>(getShape()[dim], sizePerThreads[dim] *
threadsPerWarp[dim] *
warpsPerCTA[dim]);
}
llvm_unreachable("Axis not found in order");
}
unsigned getNumScratchElements(ArrayRef<unsigned> shape) {
  // An empty shape yields 0, not the empty product 1.
  return shape.empty() ? 0u : product<unsigned>(shape);
}
static bool supportMFMAGranularity(int m, int n, int k) {
  // Each entry is {M/N granularity, K granularity}. The shape is accepted if
  // M and N divide by the first and K by the second for any entry.
  struct Granularity {
    int mn;
    int k;
  };
  static const Granularity kGranularities[] = {{32, 8}, {16, 16}};
  for (const Granularity &g : kGranularities) {
    bool mnFits = (m % g.mn == 0) && (n % g.mn == 0);
    if (mnFits && k % g.k == 0)
      return true;
  }
  return false;
}
bool supportMFMATypes(Type a, Type b) {
  // Both operands must have the same bit width.
  if (a.getIntOrFloatBitWidth() != b.getIntOrFloatBitWidth())
    return false;

  using TypePair = std::pair<TypeID, TypeID>;
  const TypeID f8e5m2 = TypeID::get<Float8E5M2Type>();
  const TypeID f8e4m3fn = TypeID::get<Float8E4M3FNType>();
  const TypeID f8e4m3fnuz = TypeID::get<Float8E4M3FNUZType>();
  const TypeID f8e5m2fnuz = TypeID::get<Float8E5M2FNUZType>();
  const TypeID f16 = TypeID::get<Float16Type>();
  const TypeID bf16 = TypeID::get<BFloat16Type>();
  const TypeID f32 = TypeID::get<Float32Type>();
  const TypeID integer = TypeID::get<IntegerType>();

  // Accepted (a, b) element-type pairs.
  const DenseSet<TypePair> allowedPairs = {
      {f32, f32},
      {f16, f16},
      {bf16, bf16},
      {f8e5m2, f8e5m2},
      {f8e4m3fn, f8e4m3fn},
      {f8e4m3fnuz, f8e4m3fnuz},
      {f8e4m3fnuz, f8e5m2fnuz},
      {f8e5m2fnuz, f8e4m3fnuz},
      {f8e5m2fnuz, f8e5m2fnuz},
      {integer, integer}};
  if (!allowedPairs.contains(TypePair(a.getTypeID(), b.getTypeID())))
    return false;

  // Integer operands are only accepted at 8 bits.
  if (a.isIntOrIndex())
    return a.getIntOrFloatBitWidth() == 8;
  return true;
}
bool supportMFMA(triton::DotOp op) {
  auto lhsTy = cast<RankedTensorType>(op.getA().getType());
  auto rhsTy = cast<RankedTensorType>(op.getB().getType());
  if (!supportMFMATypes(lhsTy.getElementType(), rhsTy.getElementType()))
    return false;

  ArrayRef<int64_t> lhsShape = lhsTy.getShape();
  ArrayRef<int64_t> rhsShape = rhsTy.getShape();
  size_t rank = lhsShape.size();
  assert(rhsShape.size() == rank);
  // A is [..., M, K] and B is [..., K, N].
  int64_t m = lhsShape[rank - 2];
  int64_t k = lhsShape[rank - 1];
  int64_t n = rhsShape[rank - 1];
  assert(k == rhsShape[rank - 2]);
  return supportMFMAGranularity(m, n, k);
}
bool supportMMA(triton::DotOp op, int version) {
  Type lhsElemTy = op.getA().getType().getElementType();
  Type rhsElemTy = op.getB().getType().getElementType();

  if (version == 3) {
    if (triton::tools::getBoolEnv("DISABLE_MMA_V3"))
      return false;
    auto retType = op.getType();
    // K must cover at least 256 bits of A's element type.
    RankedTensorType lhsTy = op.getA().getType();
    int k = lhsTy.getShape().back();
    if (k < 256 / lhsElemTy.getIntOrFloatBitWidth())
      return false;
    auto shapePerCTA = getShapePerCTA(retType);
    auto rank = shapePerCTA.size();
    auto module = op->getParentOfType<ModuleOp>();
    int numWarps = TritonGPUDialect::getNumWarps(module);
    // Rank-3 (batched) results are rejected for v3.
    if (rank == 3)
      return false;
    bool isFp8 = lhsElemTy.isFloat8E5M2() || lhsElemTy.isFloat8E4M3FN();
    bool shapeOk = numWarps % 4 == 0 && shapePerCTA[rank - 2] % 64 == 0 &&
                   shapePerCTA[rank - 1] % 8 == 0;
    bool elemTyOk = isFp8 || lhsElemTy.isInteger(8) || lhsElemTy.isF16() ||
                    lhsElemTy.isBF16() || lhsElemTy.isF32();
    if (!shapeOk || !elemTyOk)
      return false;
    // FP8 inputs with an f32 result are rejected when
    // maxNumImpreciseAcc < 32.
    bool f32Result =
        cast<RankedTensorType>(op.getType()).getElementType().isF32();
    if (isFp8 && op.getMaxNumImpreciseAcc() < 32 && f32Result)
      return false;
  }

  // f32 x f32 requires TF32 input precision and version >= 2.
  if (lhsElemTy.isF32() && rhsElemTy.isF32())
    return op.getInputPrecision() == InputPrecision::TF32 && version >= 2;
  return supportMMA(op.getA(), version) && supportMMA(op.getB(), version);
}
bool supportMMA(Value value, int version) {
  assert((version == 1 || version == 2 || version == 3) &&
         "Unexpected MMA layout version found");
  Type elemTy = cast<TensorOrMemDesc>(value.getType()).getElementType();
  // f16 / bf16 and all four FP8 variants are accepted for every version.
  if (elemTy.isF16() || elemTy.isBF16())
    return true;
  if (elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() ||
      elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ())
    return true;
  // f32 and int8 need version >= 2.
  return version >= 2 && (elemTy.isF32() || elemTy.isInteger(8));
}
bool isBlockedToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) {
  auto srcBlocked = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding());
  auto dstDot = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
  if (!srcBlocked || !dstDot)
    return false;
  auto dotParent = dyn_cast<BlockedEncodingAttr>(dstDot.getParent());
  if (!dotParent)
    return false;

  auto shape = srcTy.getShape();
  auto rank = shape.size();
  // opIdx 0 has K as the last dim; otherwise K is the second-to-last dim.
  bool isFirstOperand = dstDot.getOpIdx() == 0;
  int kDim = isFirstOperand ? rank - 1 : rank - 2;
  int nonKDim = isFirstOperand ? rank - 2 : rank - 1;

  // True if src and the dot parent agree on sizePerThread, threadsPerWarp
  // and warpsPerCTA for dimension `d`.
  auto sameDistribution = [&](unsigned d) {
    return srcBlocked.getSizePerThread()[d] ==
               dotParent.getSizePerThread()[d] &&
           srcBlocked.getThreadsPerWarp()[d] ==
               dotParent.getThreadsPerWarp()[d] &&
           srcBlocked.getWarpsPerCTA()[d] == dotParent.getWarpsPerCTA()[d];
  };

  auto srcCTALayout = srcBlocked.getCTALayout();
  bool ctaLayoutCompatible = srcCTALayout.getCTASplitNum()[kDim] == 1 &&
                             srcBlocked.getCTALayout() ==
                                 dotParent.getCTALayout();
  // Each thread must hold the full K extent.
  bool threadHoldsWholeKDim =
      srcBlocked.getSizePerThread()[kDim] == shape[kDim];
  bool nonKDimCompatible = srcBlocked.getOrder() == dotParent.getOrder() &&
                           sameDistribution(nonKDim);
  bool matrixDimsCompatible =
      ctaLayoutCompatible && threadHoldsWholeKDim && nonKDimCompatible;

  if (rank == 2)
    return matrixDimsCompatible;
  assert(rank == 3);
  // Batch dimension (dim 0) must also be distributed identically.
  return matrixDimsCompatible && sameDistribution(0);
}
bool isMfmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
  auto mfma = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
  auto dot = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
  if (!mfma || !dot)
    return false;
  // Transposed MFMA with one warp on dim 1 feeding operand 0 of a dot whose
  // parent is this same MFMA layout and whose kWidth matches dim-1 contig.
  bool layoutMatches = mfma.getWarpsPerCTA()[1] == 1 &&
                       dot.getOpIdx() == 0 && mfma.getIsTransposed() &&
                       dot.getKWidth() == getContigPerThread(mfma)[1] &&
                       dot.getParent() == mfma;
  bool supportedMDim = mfma.getMDim() == 32 || mfma.getMDim() == 16;
  Type elemTy = srcTy.getElementType();
  bool is16BitFloat = elemTy.isF16() || elemTy.isBF16();
  return layoutMatches && supportedMDim && is16BitFloat;
}
bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                   RankedTensorType dstTy) {
  auto mma = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
  auto dot = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
  if (!mma || !dot)
    return false;
  Type elemTy = srcTy.getElementType();
  int bitWidth = elemTy.getIntOrFloatBitWidth();
  // src's shape and element type, re-encoded in the dot operand's parent.
  auto parentTy =
      RankedTensorType::get(srcTy.getShape(), elemTy, dot.getParent());
  if (mma.getVersionMajor() != 3 || dot.getOpIdx() != 0 ||
      mma.getWarpsPerCTA()[1] != 1)
    return false;
  return !cvtNeedsSharedMemory(parentTy, srcTy) &&
         (bitWidth == 16 || bitWidth == 8);
}
std::optional<LinearLayout> minimalCvtLayout(RankedTensorType srcTy,
                                             RankedTensorType dstTy) {
  MLIRContext *ctx = srcTy.getContext();
  std::optional<LinearLayout> srcLayout =
      toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
  std::optional<LinearLayout> dstLayout =
      toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
  // Either encoding lacks a linear-layout form: nothing to compute.
  if (!srcLayout || !dstLayout)
    return std::nullopt;
  LinearLayout cvt = dstLayout->invertAndCompose(*srcLayout);
  // Quotient out block, warp, lane, register in that order, stopping at the
  // first dimension for which quotient() fails.
  for (StringRef dimName : {"block", "warp", "lane", "register"}) {
    auto reduced = cvt.quotient(StringAttr::get(ctx, dimName));
    if (!reduced.has_value())
      break;
    cvt = *reduced;
  }
  return cvt;
}
bool cvtReordersRegisters(RankedTensorType srcTy, RankedTensorType dstTy) {
  auto cvt = minimalCvtLayout(srcTy, dstTy);
  if (!cvt.has_value())
    return false;
  // True when the minimal conversion has no out dims, or only "register".
  auto outDims = llvm::to_vector(cvt->getOutDimNames());
  if (outDims.empty())
    return true;
  auto kRegister = StringAttr::get(srcTy.getContext(), "register");
  return outDims.size() == 1 && outDims.front() == kRegister;
}
bool cvtNeedsWarpShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
  auto cvt = minimalCvtLayout(srcTy, dstTy);
  if (!cvt.has_value())
    return false;
  // True when the out dims are exactly [register, lane], in that order.
  MLIRContext *ctx = srcTy.getContext();
  auto outDims = llvm::to_vector(cvt->getOutDimNames());
  return outDims.size() == 2 &&
         outDims[0] == StringAttr::get(ctx, "register") &&
         outDims[1] == StringAttr::get(ctx, "lane");
}
bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
  // Shared memory is required unless one of these cases applies.
  if (cvtReordersRegisters(srcTy, dstTy))
    return false;
  if (isBlockedToDotShortcut(srcTy, dstTy))
    return false;
  if (isMmaToDotShortcut(srcTy, dstTy))
    return false;
  return !isMfmaToDotShortcut(srcTy, dstTy);
}
bool atomicNeedsSharedMemory(Value value) {
  // Only a non-ranked-tensor result that actually has uses needs shared
  // memory.
  bool isTensor = isa<RankedTensorType>(value.getType());
  return !isTensor && !value.use_empty();
}
bool isMmaToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy) {
  if (matchMmaV3AndDotOperandLayout(srcTy, dstTy))
    return true;
  auto mma = dyn_cast<NvidiaMmaEncodingAttr>(srcTy.getEncoding());
  auto dot = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
  if (!mma || !dot)
    return false;
  // MMA v2 with one warp on dim 1, feeding operand 0 of a dot whose parent
  // is this same MMA layout, for non-f32 elements.
  if (mma.getVersionMajor() != 2 || mma.getWarpsPerCTA()[1] != 1)
    return false;
  return dot.getOpIdx() == 0 && dot.getParent() == mma &&
         !srcTy.getElementType().isF32();
}
namespace {
struct DFSSubgraphState {
DFSSubgraphState() : set(), deque() {}
DenseSet<Operation *> set;
std::deque<Operation *> deque;
bool push_back(Operation *op) {
if (set.insert(op).second) {
deque.push_back(op);
return true;
}
return false;
}
Operation *pop_front() {
Operation *op = deque.front();
deque.pop_front();
set.erase(op);
return op;
}
bool empty() { return deque.empty(); }
};
struct DFSState {
DFSState(const SetVector<Operation *> &set) : toSort(set), seen() {}
const SetVector<Operation *> &toSort;
SmallVector<Operation *, 16> topologicalCounts;
DenseSet<Operation *> seen;
void addToReadyQueue(Operation *op, DFSSubgraphState &subGraph,
SmallVector<Operation *, 4> &readyQueue) {
bool ready = true;
for (Value operand : op->getOperands()) {
auto def = operand.getDefiningOp();
if (def && !seen.count(def)) {
subGraph.push_back(def);
ready = false;
}
}
Operation *parent = op->getParentOp();
while (parent) {
if (!seen.count(parent)) {
subGraph.push_back(parent);
ready = false;
}
parent = parent->getParentOp();
}
if (ready)
readyQueue.push_back(op);
}
};
void dfsPostorder(Operation *root, DFSState *state) {
DFSSubgraphState subGraph;
subGraph.push_back(root);
SmallVector<Operation *> ops;
while (!subGraph.empty()) {
SmallVector<Operation *, 4> readyQueue;
auto *current = subGraph.pop_front();
state->addToReadyQueue(current, subGraph, readyQueue);
while (!readyQueue.empty()) {
Operation *current = readyQueue.pop_back_val();
if (!state->seen.insert(current).second)
continue;
ops.push_back(current);
for (Value result : current->getResults()) {
for (Operation *op : result.getUsers())
state->addToReadyQueue(op, subGraph, readyQueue);
}
for (Region ®ion : current->getRegions()) {
for (Operation &op : region.getOps())
state->addToReadyQueue(&op, subGraph, readyQueue);
}
}
}
for (Operation *op : llvm::reverse(ops)) {
if (state->toSort.count(op) > 0)
state->topologicalCounts.push_back(op);
}
}
}
SetVector<Operation *>
multiRootTopologicalSort(const SetVector<Operation *> &toSort) {
  if (toSort.empty())
    return toSort;
  // One shared DFS state across all roots, so ops already seen from an
  // earlier root are not emitted again.
  DFSState state(toSort);
  for (Operation *root : toSort) {
    assert(toSort.count(root) == 1 && "NYI: multi-sets not supported");
    dfsPostorder(root, &state);
  }
  // Output the collected ops in reverse of the order they were appended.
  SetVector<Operation *> sorted;
  for (Operation *op : llvm::reverse(state.topologicalCounts))
    sorted.insert(op);
  return sorted;
}
SetVector<Operation *> multiRootGetSlice(Operation *op,
                                         TransitiveFilter backwardFilter,
                                         TransitiveFilter forwardFilter) {
  SetVector<Operation *> slice;
  slice.insert(op);

  BackwardSliceOptions backwardOpts;
  backwardOpts.omitBlockArguments = true;
  backwardOpts.filter = backwardFilter;

  SetVector<Operation *> backwardSlice;
  SetVector<Operation *> forwardSlice;
  // `slice` doubles as the work list: ops appended during an iteration are
  // expanded in later iterations until no new op is added.
  for (unsigned i = 0; i != slice.size(); ++i) {
    Operation *current = slice[i];
    backwardSlice.clear();
    getBackwardSlice(current, &backwardSlice, backwardOpts);
    slice.insert(backwardSlice.begin(), backwardSlice.end());
    forwardSlice.clear();
    getForwardSlice(current, &forwardSlice, forwardFilter);
    slice.insert(forwardSlice.begin(), forwardSlice.end());
  }
  return multiRootTopologicalSort(slice);
}
namespace {
/// Seeds constant lattices: results of ops matching `m_Constant` get their
/// constant value. All other op results, and block arguments of their
/// regions, are joined with the unknown constant. Non-constant ops are not
/// folded.
class ConstantAnalysis : public DataFlowAnalysis {
public:
  using DataFlowAnalysis::DataFlowAnalysis;

  /// Visits every op nested under `top`, including `top`; fails if any visit
  /// fails.
  LogicalResult initialize(Operation *top) override {
    WalkResult result = top->walk([&](Operation *op) {
      ProgramPoint programPoint(op);
      if (failed(visit(&programPoint)))
        return WalkResult::interrupt();
      return WalkResult::advance();
    });
    return success(!result.wasInterrupted());
  }

  LogicalResult visit(ProgramPoint *point) override {
    Operation *op = point->getOperation();
    Attribute value;
    if (matchPattern(op, m_Constant(&value))) {
      auto *constant = getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(
          op->getResult(0));
      propagateIfChanged(constant, constant->join(dataflow::ConstantValue(
                                       value, op->getDialect())));
      return success();
    }
    // Not a constant: mark results and region block arguments unknown.
    setAllToUnknownConstants(op->getResults());
    for (Region &region : op->getRegions()) {
      for (Block &block : region.getBlocks())
        setAllToUnknownConstants(block.getArguments());
    }
    return success();
  }

private:
  /// Joins each value's lattice with ConstantValue(nullptr, nullptr).
  void setAllToUnknownConstants(ValueRange values) {
    dataflow::ConstantValue unknownConstant(nullptr, nullptr);
    for (Value value : values) {
      auto *constant =
          getOrCreate<dataflow::Lattice<dataflow::ConstantValue>>(value);
      propagateIfChanged(constant, constant->join(unknownConstant));
    }
  }
};
} // namespace
std::unique_ptr<DataFlowSolver> createDataFlowSolver() {
auto solver = std::make_unique<DataFlowSolver>();
solver->load<dataflow::DeadCodeAnalysis>();
solver->load<ConstantAnalysis>();
return solver;
}
static MakeTensorPtrOp getMakeTensorPtrOpImpl(Operation *op, Value v) {
  // `v` is a result of `op`; trace it back to its MakeTensorPtrOp.
  if (auto makeTensorPtrOp = dyn_cast<MakeTensorPtrOp>(op))
    return makeTensorPtrOp;
  // Follow the pointer operand of an advance.
  if (auto advanceOp = dyn_cast<AdvanceOp>(op))
    return getMakeTensorPtrOp(advanceOp.getPtr());
  if (isa<RegionBranchOpInterface>(op)) {
    // Follow operand #idx of an scf.yield that terminates one of *this* op's
    // regions. Operation::walk is post-order, so an unfiltered walk would
    // also visit (and could pick first) yields of nested region ops.
    auto idx = cast<OpResult>(v).getResultNumber();
    llvm::SmallVector<scf::YieldOp> yieldOps;
    op->walk([&](scf::YieldOp yieldOp) {
      if (yieldOp->getParentOp() == op)
        yieldOps.push_back(yieldOp);
    });
    // NOTE: with several yields they are assumed to agree on the source
    // pointer; only the first is followed.
    if (!yieldOps.empty())
      return getMakeTensorPtrOp(yieldOps.front().getOperands()[idx]);
  }
  llvm_unreachable("Unable to getMakeTensorPtr()");
}
MakeTensorPtrOp getMakeTensorPtrOp(Value v) {
  // Op results are resolved from their defining op.
  if (Operation *definingOp = v.getDefiningOp())
    return getMakeTensorPtrOpImpl(definingOp, v);

  BlockArgument arg = cast<BlockArgument>(v);
  unsigned argNum = arg.getArgNumber();
  Block *block = arg.getOwner();
  Operation *argOwner = block->getParentOp();

  // scf.for region arg #argNum maps to operand
  // argNum + numControlOperands - 1 of the loop.
  if (auto forOp = dyn_cast<scf::ForOp>(argOwner))
    return getMakeTensorPtrOp(
        forOp.getOperand(argNum + forOp.getNumControlOperands() - 1));

  if (isa<FunctionOpInterface>(argOwner)) {
    // Argument of a block in a function body: find the first cf branch (in
    // walk order) targeting `block` and follow the operand it forwards.
    // Successors always live in the terminator's own region, so walking the
    // function suffices. Doing it only here avoids a whole-module walk on
    // every recursive call.
    Operation *branchOp = nullptr;
    bool viaTrueDest = false;
    argOwner->walk([&](Operation *op) {
      if (auto br = dyn_cast<cf::BranchOp>(op)) {
        if (br.getDest() == block) {
          branchOp = op;
          return WalkResult::interrupt();
        }
      } else if (auto condBr = dyn_cast<cf::CondBranchOp>(op)) {
        if (condBr.getTrueDest() == block) {
          branchOp = op;
          viaTrueDest = true;
          return WalkResult::interrupt();
        }
        if (condBr.getFalseDest() == block) {
          branchOp = op;
          viaTrueDest = false;
          return WalkResult::interrupt();
        }
      }
      return WalkResult::advance();
    });
    // No predecessor branch (e.g. the entry block): fall through to the
    // unreachable below instead of indexing an empty list.
    if (auto br = dyn_cast_or_null<cf::BranchOp>(branchOp))
      return getMakeTensorPtrOp(br.getDestOperands()[argNum]);
    if (auto condBr = dyn_cast_or_null<cf::CondBranchOp>(branchOp))
      return getMakeTensorPtrOp(viaTrueDest
                                    ? condBr.getTrueDestOperands()[argNum]
                                    : condBr.getFalseDestOperands()[argNum]);
  }
  llvm_unreachable("Unable to getMakeTensorPtr()");
}
}