#include <vector>
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h"
#include "triton/Tools/LayoutUtils.h"
#include "triton/Tools/LinearLayout.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
using mlir::triton::nvidia_gpu::TensorMemoryEncodingAttr;
using mlir::triton::nvidia_gpu::TensorMemoryScalesEncodingAttr;
namespace mlir::triton::gpu {
namespace {
#define S(v) StringAttr::get(ctx, (v))
// Default dimension order for MMA-family layouts: the row-major matrix order
// returned by getMatrixOrder, for the rank implied by the operand-0
// repetition order of `layout`.
SmallVector<unsigned> getDefaultMmaOrder(MmaEncodingTrait layout) {
  const auto numDims = layout.getRepOrderForOperand(0).size();
  const bool rowMajor = true;
  return getMatrixOrder(numDims, rowMajor);
}
// Returns `names` permuted by `order`, i.e. result[i] == names[order[i]].
// `names` and `order` must have the same length (asserted).
SmallVector<StringAttr> permuteDimNames(const SmallVector<StringAttr> &names,
                                        const SmallVector<unsigned> &order) {
  assert(names.size() == order.size());
  SmallVector<StringAttr> permuted;
  permuted.reserve(order.size());
  for (size_t pos = 0; pos < order.size(); ++pos)
    permuted.push_back(names[order[pos]]);
  return permuted;
}
// Builds the CGA-level layout (in-dim "block") described by `layout`.
//
// Dimensions are visited in CTAOrder, starting with CTAOrder[0]. Along each
// dimension, the first `split` CTAs are mapped with an identity (distinct
// coordinates). The remaining `ctas / split` factor is mapped with zeros, so
// those CTAs map to the same coordinates. The result's out-dims are put in
// standard order (dim0, dim1, ...).
LinearLayout makeCgaLayout(CTALayoutAttr layout) {
  MLIRContext *ctx = layout.getContext();
  StringAttr kBlock = S("block");

  auto ctaOrder = layout.getCTAOrder();
  auto splitNum = layout.getCTASplitNum();
  auto ctasPerCga = layout.getCTAsPerCGA();
  int rank = ctaOrder.size();
  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);

  LinearLayout cga = LinearLayout::empty();
  for (int dim : ctaOrder) {
    int split = splitNum[dim];
    int ctas = ctasPerCga[dim];
    assert(ctas % split == 0);
    LinearLayout distinct =
        LinearLayout::identity1D(split, kBlock, outDimNames[dim]);
    LinearLayout repeated =
        LinearLayout::zeros1D(ctas / split, kBlock, outDimNames[dim]);
    cga *= distinct * repeated;
  }
  return cga.transposeOuts(outDimNames);
}
// Extends a per-CTA layout to the full tensor `shape` by multiplying it with
// the CGA ("block") layout built from `cgaLayoutAttr`.
//
// ctaLayout     - per-CTA layout; must have `rank` out-dims (asserted). Its
//                 out-dim order is kept while combining and is only transposed
//                 to standard order (dim0, dim1, ...) at the end.
// cgaLayoutAttr - CTA/CGA description; its CTAOrder must have `rank` entries
//                 (asserted).
// shape         - full tensor shape.
//
// Returns ctaLayout * cgaLayout in standard out-dim order. Every out-dim size
// of the result is asserted to equal `shape`.
LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
CTALayoutAttr cgaLayoutAttr,
ArrayRef<int64_t> shape) {
int rank = shape.size();
assert(ctaLayout.getNumOutDims() == rank);
assert(cgaLayoutAttr.getCTAOrder().size() == rank);
MLIRContext *ctx = cgaLayoutAttr.getContext();
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
// Label each extent of `shape` with its standard out-dim name.
llvm::SmallDenseMap<StringAttr, int64_t> labeledShape;
for (auto [dim, size] : llvm::zip(outDimNames, shape)) {
labeledShape[dim] = size;
}
// CGA layout clamped so it is not larger than `shape`, with its out-dims put
// in ctaLayout's order so that the two can be multiplied.
LinearLayout cgaLayout =
ensureLayoutNotLargerThan(makeCgaLayout(cgaLayoutAttr), labeledShape)
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
// Per-CTA extent of each dim: full extent divided by the CGA extent, never
// below 1.
llvm::SmallDenseMap<StringAttr, int64_t> ctaShape;
assert(llvm::to_vector(ctaLayout.getOutDimNames()) ==
llvm::to_vector(cgaLayout.getOutDimNames()));
for (auto dim : ctaLayout.getOutDimNames()) {
ctaShape[dim] =
std::max(int64_t{1}, labeledShape[dim] / cgaLayout.getOutDimSize(dim));
}
// Make ctaLayout neither smaller nor larger than the per-CTA shape.
ctaLayout = ensureLayoutNotSmallerThan(ctaLayout, ctaShape);
ctaLayout = ensureLayoutNotLargerThan(ctaLayout, ctaShape);
LinearLayout ret = (ctaLayout * cgaLayout).transposeOuts(outDimNames);
// Sanity check: the combined layout covers exactly `shape`.
for (auto dim : ret.getOutDimNames()) {
assert(ret.getOutDimSize(dim) == labeledShape[dim]);
}
return ret;
}
// Linear layout of a swizzled shared-memory encoding. It maps the "offset"
// in-dim to tensor coordinates.
//
// Rank-1 tensors are a plain identity over the per-CTA extent. Otherwise:
// - The two fastest dims of the encoding's order are swizzled:
//   order[0] = columns, order[1] = rows.
// - Column bases are the identity.
// - The basis for row `r` is also shifted along the columns by
//   vec * ((r / perPhase) % maxPhase), wrapped modulo the per-CTA column count.
// - The remaining dims are appended as identities, following the order.
// The per-CTA layout is then extended to `shape` with the CTA layout.
LinearLayout swizzledSharedToLinearLayout(ArrayRef<int64_t> shape,
                                          SwizzledSharedEncodingAttr shared) {
  MLIRContext *ctx = shared.getContext();
  StringAttr kOffset = S("offset");
  auto shapePerCTA = getShapePerCTA(shared, shape);
  int rank = shape.size();

  if (rank == 1) {
    LinearLayout flat =
        LinearLayout::identity1D(shapePerCTA[0], kOffset, S("dim0"));
    return combineCtaCgaWithShape(flat, shared.getCTALayout(), shape);
  }

  auto outDimNames = standardOutDimNames(ctx, rank);
  assert(shape.size() >= 2);
  auto order = shared.getOrder();
  const int colDim = order[0];
  const int rowDim = order[1];
  const int numCols = shapePerCTA[colDim];
  const int numRows = shapePerCTA[rowDim];
  const int vec = shared.getVec();
  const int perPhase = shared.getPerPhase();
  const int maxPhase = shared.getMaxPhase();

  std::vector<std::vector<int>> offsetBases;
  for (int col = 1; col < numCols; col *= 2)
    offsetBases.push_back({0, col});
  for (int row = 1; row < numRows; row *= 2) {
    const int phase = (row / perPhase) % maxPhase;
    offsetBases.push_back({row, (vec * phase) % numCols});
  }

  LinearLayout ctaLayout({{kOffset, offsetBases}},
                         {outDimNames[rowDim], outDimNames[colDim]});
  for (int i = 2; i < rank; ++i) {
    const int dim = order[i];
    ctaLayout *=
        LinearLayout::identity1D(shapePerCTA[dim], kOffset, outDimNames[dim]);
  }
  return combineCtaCgaWithShape(ctaLayout, shared.getCTALayout(), shape);
}
// Linear layout of AMD's rotating shared-memory encoding. It maps the
// "offset" in-dim to tensor coordinates.
//
// As in the swizzled shared layout:
// - The two fastest dims of the encoding's order are swizzled:
//   order[0] = columns, order[1] = rows.
// - The remaining dims are appended as identities, following the order.
//
// The column shift of row `r` is vec * (phase ^ blockNo), modulo the column
// count, where:
//   phase   = (r / perPhase) % maxPhase
//   blockNo = (r / maxPhase / perPhase) % maxPhase
// So the phase pattern is XOR-ed with a value that changes every
// perPhase * maxPhase rows.
//
// All extents are per-CTA (getShapePerCTA), consistent with the rank-1 path
// and with swizzledSharedToLinearLayout. combineCtaCgaWithShape then extends
// the per-CTA layout to `shape`.
LinearLayout
sharedToLinearLayoutAMDRotating(ArrayRef<int64_t> shape,
                                AMDRotatingSharedEncodingAttr shared) {
  MLIRContext *ctx = shared.getContext();
  auto shapePerCTA = getShapePerCTA(shared, shape);
  int rank = shape.size();
  if (rank == 1) {
    return combineCtaCgaWithShape(
        LinearLayout::identity1D(shapePerCTA[0], S("offset"), S("dim0")),
        shared.getCTALayout(), shape);
  }

  auto outDimNames = standardOutDimNames(ctx, rank);
  assert(shape.size() >= 2);
  int colDim = shared.getOrder()[0];
  int rowDim = shared.getOrder()[1];
  // Use per-CTA extents. The CTA layout is applied afterwards by
  // combineCtaCgaWithShape, so the swizzle must wrap within one CTA's tile.
  // (Using the full `shape` here made the rows' column shifts wrap modulo the
  // whole tensor's column count when the CTA layout splits the tensor.)
  int numCols = shapePerCTA[colDim];
  int numRows = shapePerCTA[rowDim];
  StringAttr colDimName = outDimNames[colDim];
  StringAttr rowDimName = outDimNames[rowDim];

  int vec = shared.getVec();
  int perPhase = shared.getPerPhase();
  int maxPhase = shared.getMaxPhase();

  std::vector<std::vector<int>> bases2D;
  for (int col = 1; col < numCols; col *= 2) {
    bases2D.push_back({0, col});
  }
  for (int row = 1; row < numRows; row *= 2) {
    int phase = (row / perPhase) % maxPhase;
    int blockNo = row / maxPhase / perPhase % maxPhase;
    int combinedPhase = phase ^ blockNo;
    bases2D.push_back({row, (vec * combinedPhase) % numCols});
  }
  LinearLayout ctaLayout =
      LinearLayout({{S("offset"), bases2D}}, {rowDimName, colDimName});

  for (int i = 2; i < rank; i++) {
    int dim = shared.getOrder()[i];
    ctaLayout *= LinearLayout::identity1D(shapePerCTA[dim], S("offset"),
                                          outDimNames[dim]);
  }
  return combineCtaCgaWithShape(ctaLayout, shared.getCTALayout(), shape);
}
// Layout of a single NVMMA core-matrix tile. It maps the "offset" in-dim to
// (dim0 = row, dim1 = col) coordinates.
//
// The tile has 8 rows and 8 * swizzlingByteWidth / elementBitWidth columns.
// - Column bases are the identity.
// - Row `r` is shifted along the columns by vec * ((r / perPhase) % maxPhase),
//   unless `disableSwizzle` is set, in which case the row bases carry no
//   column shift.
// - For fp4-padded encodings, every column value goes through
//   c / 16 * 8 + c % 8. Offsets 8..15 of each group of 16 therefore alias
//   offsets 0..7.
LinearLayout getCoreMatrixLinearLayout(NVMMASharedEncodingAttr shared,
                                       bool disableSwizzle) {
  auto *ctx = shared.getContext();
  const int elemBitWidth = shared.getElementBitWidth();
  const int tileWidthBytes = shared.getSwizzlingByteWidth();
  const int vec = shared.getVec();
  const int perPhase = shared.getPerPhase();
  const int maxPhase = shared.getMaxPhase();
  const bool isFp4Padded = shared.getFp4Padded();

  constexpr int tileRows = 8;
  const int tileCols = 8 * tileWidthBytes / elemBitWidth;

  auto packColumn = [isFp4Padded](int col) {
    return isFp4Padded ? col / 16 * 8 + col % 8 : col;
  };

  std::vector<std::vector<int>> offsetBases;
  for (int col = 1; col < tileCols; col *= 2)
    offsetBases.push_back({0, packColumn(col)});
  for (int row = 1; row < tileRows; row *= 2) {
    // Short-circuit so the swizzle arithmetic is skipped when it is disabled.
    const int shift =
        disableSwizzle ? 0 : packColumn(vec * ((row / perPhase) % maxPhase));
    offsetBases.push_back({row, shift});
  }

  auto outDimNames = standardOutDimNames(ctx, 2);
  return LinearLayout({{S("offset"), offsetBases}}, outDimNames);
}
}
// Linear layout of an NVMMA shared-memory encoding. It maps the "offset"
// in-dim to tensor coordinates.
//
// shape          - full tensor shape.
// shared         - the NVMMA shared encoding.
// disableSwizzle - forwarded to getCoreMatrixLinearLayout. When true, the
//                  core-matrix rows are not swizzled.
//
// Without swizzling (swizzlingByteWidth == 0), the layout is an identity over
// the TMA block shape, with the last dim fastest. It is then extended up to
// shapePerCTA.
//
// Otherwise:
// - The outer dims of the TMA block shape are collapsed into one row dim.
// - The core-matrix tile is repeated over that 2-D shape.
// - The result is reshaped back to N-D, undoing the transposition when
//   getTransposed().
//
// Both paths finish with combineCtaCgaWithShape.
LinearLayout nvmmaSharedToLinearLayout(ArrayRef<int64_t> shape,
NVMMASharedEncodingAttr shared,
bool disableSwizzle) {
MLIRContext *ctx = shared.getContext();
int rank = shape.size();
auto shapePerCTA = getShapePerCTA(shared, shape);
auto kOffset = S("offset");
// TMA block shape of one CTA. NOTE(review): the meaning of the trailing
// `true` flag is defined by getTMABlockShape — TODO confirm.
auto tmaShape = triton::nvidia_gpu::getTMABlockShape(shared, shapePerCTA,
true);
if (shared.getSwizzlingByteWidth() == 0) {
// Unswizzled: identity over the TMA block, last dim fastest, then the
// remaining dims from rank - 2 down to 0.
auto outDimNames = standardOutDimNames(ctx, rank);
LinearLayout layout = LinearLayout::identity1D(tmaShape[rank - 1], kOffset,
outDimNames[rank - 1]);
for (int i = rank - 2; i >= 0; --i) {
layout *= LinearLayout::identity1D(tmaShape[i], kOffset, outDimNames[i]);
}
layout = ensureLayoutNotSmallerThan(layout, outDimNames, shapePerCTA);
return combineCtaCgaWithShape(layout, shared.getCTALayout(), shape);
}
assert(rank >= 2);
// Collapse all dims except the last into one:
// [prod(tmaShape[0 .. rank-2]), tmaShape[rank-1]]. The two entries are
// swapped for transposed encodings.
std::array<int64_t, 2> collapsedTmaShape{1, tmaShape.back()};
for (int i = 0; i + 1 < rank; i++)
collapsedTmaShape[0] *= tmaShape[i];
if (shared.getTransposed()) {
std::swap(collapsedTmaShape[0], collapsedTmaShape[1]);
}
auto tileLayout = getCoreMatrixLinearLayout(shared, disableSwizzle);
auto outDimNames = standardOutDimNames(ctx, 2);
auto kRow = outDimNames[0];
auto kCol = outDimNames[1];
auto tileRows = tileLayout.getOutDimSize(kRow);
auto tileCols = tileLayout.getOutDimSize(kCol);
// The collapsed shape must hold at least one core-matrix tile. For
// fp4-padded encodings, the column count is doubled before the comparison.
// NOTE(review): presumably tile columns are counted in padded units while
// the TMA shape is packed — confirm.
int packingFactor = shared.getFp4Padded() ? 2 : 1;
if (collapsedTmaShape[1] * packingFactor < tileCols ||
collapsedTmaShape[0] < tileRows) {
llvm::errs() << "Illegal shared layout; expected collapsed shapePerCTA to "
"be at least ["
<< tileRows << ", " << (tileCols / packingFactor)
<< "], collapsedTmaShape: [" << collapsedTmaShape[0] << ", "
<< collapsedTmaShape[1] << "]\n";
llvm::report_fatal_error("Illegal shared layout");
}
// Repeat the core-matrix tile until it covers the collapsed 2-D shape.
auto layout =
ensureLayoutNotSmallerThan(tileLayout, outDimNames, collapsedTmaShape);
// Reshape to N-D. For transposed encodings, the first (outer) TMA dim is
// rotated to the last position to match the swapped collapsed shape.
SmallVector<int64_t> maybeTransposedTmaShape = tmaShape;
if (shared.getTransposed()) {
std::rotate(maybeTransposedTmaShape.begin(),
maybeTransposedTmaShape.begin() + 1,
maybeTransposedTmaShape.end());
}
auto reshapedLayout = reshapeLayout(ctx, layout, maybeTransposedTmaShape);
// Undo that rotation on the out-dims with order
// [rank - 1, 0, 1, ..., rank - 2].
if (shared.getTransposed()) {
SmallVector<int> order = {rank - 1};
for (int i = 0; i < rank - 1; i++) {
order.push_back(i);
}
reshapedLayout = transposeLinearLayout(reshapedLayout, order);
}
// Extend from the TMA block up to the full per-CTA shape.
reshapedLayout = ensureLayoutNotSmallerThan(
reshapedLayout, standardOutDimNames(ctx, shapePerCTA.size()),
shapePerCTA);
return combineCtaCgaWithShape(reshapedLayout, shared.getCTALayout(), shape);
}
// Builds an N-D layout for in-dim `inDimName` that visits the dims of `shape`
// in `order`.
// - Every dim except `kDim` is distributed with an identity.
// - Along `kDim`, every index maps to coordinate 0, i.e. it is broadcast.
static LinearLayout broadcastedDotOperandLayout(MLIRContext *ctx,
                                                ArrayRef<unsigned> shape,
                                                ArrayRef<unsigned> order,
                                                unsigned kDim,
                                                StringAttr inDimName) {
  auto dimNames = standardOutDimNames(ctx, shape.size());
  LinearLayout result = LinearLayout::empty();
  for (unsigned d : order) {
    result *= (d == kDim)
                  ? LinearLayout::zeros1D(shape[d], inDimName, dimNames[d])
                  : LinearLayout::identity1D(shape[d], inDimName, dimNames[d]);
  }
  return result;
}
// Linear layout of the MFMA accumulator layout. It maps the register, lane
// and warp in-dims to tensor coordinates.
//
// Rank 2 is [M, N]. Rank 3 adds a leading batch dim, which shifts
// mIndex/nIndex by one. The single-instruction tile is built on
// dimM = outDimNames[order[1]] and dimN = outDimNames[order[0]], with
// `order` from getDefaultMmaOrder.
//
// The N dimension is expanded over the whole tensor first:
// tilesPerWarpN, then warpsPerCTAN, then any further N repetitions as
// registers. Only after that is the layout extended along M:
// tilesPerWarpM, then warpsPerCTAM. combineCtaCgaWithShape handles whatever
// remains.
LinearLayout
AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
int rank = shape.size();
assert(rank == getRank());
bool hasBatchDim = rank == 3;
int mIndex = 0 + hasBatchDim;
int nIndex = 1 + hasBatchDim;
(void)mIndex, (void)nIndex;
MLIRContext *ctx = getContext();
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
StringAttr kRegister = S("register");
StringAttr kLane = S("lane");
StringAttr kWarp = S("warp");
SmallVector<unsigned> order = getDefaultMmaOrder(*this);
auto dimM = outDimNames[order[1]];
auto dimN = outDimNames[order[0]];
unsigned mDim = getMDim();
unsigned nDim = getNDim();
// Consecutive register elements per lane group: 1 for f64, otherwise 4.
auto elementType = getElementType();
int height = (elementType && elementType->isF64()) ? 1 : 4;
constexpr int warpSize = 64;
bool isTransposed = getIsTransposed();
// The 64x4 instruction is only supported in its transposed form.
if (mDim == 64 && nDim == 4)
assert(isTransposed && "64x4 mfma must be transposed");
// Number of times the per-lane register group repeats to cover one
// mDim x nDim instruction tile. This is integer division and may be 0, in
// which case no repetition is added.
int tiles = (mDim * nDim) / (warpSize * height);
LinearLayout tileLayout = LinearLayout::empty();
if (!isTransposed) {
// Registers along M. Lanes go first along N, then along M.
LinearLayout regs = LinearLayout::identity1D(height, kRegister, dimM);
LinearLayout lanes = LinearLayout::identity1D(nDim, kLane, dimN) *
LinearLayout::identity1D(warpSize / nDim, kLane, dimM);
tileLayout = (regs * lanes);
if (tiles > 0)
tileLayout *= LinearLayout::identity1D(tiles, kRegister, dimM);
} else {
// Transposed: same construction with the roles of M and N swapped.
LinearLayout regs = LinearLayout::identity1D(height, kRegister, dimN);
LinearLayout lanes = LinearLayout::identity1D(mDim, kLane, dimM) *
LinearLayout::identity1D(warpSize / mDim, kLane, dimN);
tileLayout = (regs * lanes);
if (tiles > 0)
tileLayout *= LinearLayout::identity1D(tiles, kRegister, dimN);
}
tileLayout = tileLayout.transposeOuts({dimN, dimM});
auto tilesPerWarp = getTilesPerWarp();
auto warpsPerCTA = getWarpsPerCTA();
const unsigned tilesPerWarpM = tilesPerWarp[mIndex];
const unsigned tilesPerWarpN = tilesPerWarp[nIndex];
const unsigned warpsPerCTAM = warpsPerCTA[mIndex];
const unsigned warpsPerCTAN = warpsPerCTA[nIndex];
// Extend along N: per-warp tiles, then warps.
tileLayout *= LinearLayout::identity1D(tilesPerWarpN, kRegister, dimN);
tileLayout *= LinearLayout::identity1D(warpsPerCTAN, kWarp, dimN);
// Remaining N repetitions beyond one CTA tile, placed on registers.
// NOTE(review): this factor is 0 when shape[nIndex] is smaller than one CTA
// tile. Presumably identity1D(0, ...) yields an empty layout — confirm.
tileLayout *= LinearLayout::identity1D(
shape[nIndex] / (getNDim() * warpsPerCTAN * tilesPerWarpN), kRegister,
dimN);
// Then extend along M: per-warp tiles, then warps.
tileLayout *= LinearLayout::identity1D(tilesPerWarpM, kRegister, dimM);
tileLayout *= LinearLayout::identity1D(warpsPerCTAM, kWarp, dimM);
if (hasBatchDim) {
// Batch dim: registers and lanes do not span it; warps are split by
// warpsPerCTA[0].
assert(order[2] == 0);
tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]);
tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
tileLayout *=
LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[order[2]]);
}
return combineCtaCgaWithShape(tileLayout, getCTALayout(), shape);
}
// Builds the linear layout of an MFMA dot operand that is read from LDS with
// the transposing ds_read_b64_tr_* path.
//
// dotMfmaLayout - dot operand layout. Its parent must be an AMDMfmaEncodingAttr
//                 with mDim 16 or 32 (asserted).
// shape         - operand shape: rank 2, or rank 3 with a leading batch dim.
// elemBitWidth  - element width in bits: 16, 8 or 4. The 4-bit case is
//                 described on 8-bit elements (isFP4 path).
//
// How the layout is assembled:
// - The per-warp tile uses explicit register/lane bases over
//   (outDimNames[order[0]], outDimNames[order[1]]), where `order` is
//   getOrderForDotOperand(opIdx, rank, false).
// - In these bases, the second coordinate is the one extended per K tile
//   (see extendRegisterBaseForKDim).
// - Warps are laid out with identityStandardND in the default MMA order.
// - The result is extended to `shape` by combineCtaCgaWithShape.
LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout,
                                        ArrayRef<int64_t> shape,
                                        int32_t elemBitWidth) {
  auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
  auto mDim = mfmaLayout.getMDim();
  assert(mDim == 16 || mDim == 32);

  bool isFP4 = false;
  if (elemBitWidth == 4) {
    // Describe the 4-bit case on 8-bit elements.
    elemBitWidth = 8;
    isFP4 = true;
  }
  assert(elemBitWidth == 16 || elemBitWidth == 8);

  auto rank = shape.size();
  bool hasBatchDim = rank == 3;
  int32_t kWidthDot = dotMfmaLayout.getKWidth();
  auto kDim = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
  int32_t kSize = shape[kDim];
  auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();

  MLIRContext *ctx = dotMfmaLayout.getContext();
  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
  StringAttr kRegister = S("register");
  StringAttr kLane = S("lane");
  StringAttr kWarp = S("warp");

  SmallVector<unsigned> order =
      getOrderForDotOperand(dotMfmaLayout.getOpIdx(), rank, false);

  std::vector<std::vector<int32_t>> registerBase;
  std::vector<std::vector<int32_t>> laneBase;

  // Bases for the 4-bit (FP4) case. Both vectors are captured by reference
  // so the lambda fills the enclosing bases.
  auto populateFP4LL = [&registerBase, &laneBase](int kSize, int mDim) {
    const bool isMfma32 = (mDim == 32);
    registerBase.push_back({1, 0});
    registerBase.push_back({2, 0});
    registerBase.push_back({4, 0});
    registerBase.push_back({0, 16});
    // Extra register bases when K spans more than one tile.
    const int kTileSize = isMfma32 ? 64 : 128;
    for (int reg = kTileSize; reg < kSize; reg *= 2) {
      registerBase.push_back({0, reg});
    }
    laneBase.push_back({0, 1});
    laneBase.push_back({0, 2});
    laneBase.push_back({0, 4});
    laneBase.push_back({0, 8});
    if (mDim == 16) {
      laneBase.push_back({0, 32});
      laneBase.push_back({0, 64});
    } else {
      assert(mDim == 32);
      laneBase.push_back({8, 0});
      laneBase.push_back({0, 32});
    }
  };

  // Bases for 16-bit and 8-bit elements.
  auto populateLL = [&registerBase, &laneBase](int elemBitWidth, int kSize,
                                               int kWidthDot, int mDim) {
    // Each transposing read moves 64 bits per thread.
    const int32_t ldsReadWidth = 64;
    int32_t kWidthTransRead = ldsReadWidth / elemBitWidth;
    const int elemByteWidth = elemBitWidth / 8;
    const bool isMfma32 = (mDim == 32);

    // Registers of the first sub-tile.
    for (int i = 1; i < kWidthTransRead; i *= 2) {
      registerBase.push_back({i, 0});
    }

    // Lanes of the first sub-tile: first along the first coordinate (in steps
    // of kWidthTransRead), then along the second.
    const int threadsPerSubtileNonK = 16 / kWidthTransRead;
    const int threadsPerSubtileK = kWidthTransRead;
    for (int i = 1; i < threadsPerSubtileNonK; i *= 2) {
      laneBase.push_back({i * kWidthTransRead, 0});
    }
    for (int i = 1; i < threadsPerSubtileK; i *= 2) {
      laneBase.push_back({0, i});
    }

    // Adds register bases that step over whole K tiles when kSize spans
    // several of them.
    auto extendRegisterBaseForKDim = [&](int kTileSize,
                                         int numSubtilesPerTile) {
      const int regsPerTile = kWidthTransRead * numSubtilesPerTile;
      int totalRegs = (kSize / kTileSize) * regsPerTile;
      for (int reg = regsPerTile; reg < totalRegs; reg *= 2) {
        registerBase.push_back({0, (reg / regsPerTile) * kTileSize});
      }
    };

    // A tile holds two sub-tiles when its K extent (from kWidthDot) equals the
    // "double" K size for this instruction shape and element width.
    const int kDoubleTileSize =
        isMfma32 ? 32 / elemByteWidth : 64 / elemByteWidth;
    const int kTileSize = kWidthDot * 64 / mDim;
    const int numSubtilesPerTile = (kTileSize == kDoubleTileSize) ? 2 : 1;
    if (numSubtilesPerTile == 2)
      registerBase.push_back({0, threadsPerSubtileK});
    extendRegisterBaseForKDim(kTileSize, numSubtilesPerTile);

    // Remaining lane bits depend on the instruction size.
    std::vector<std::vector<int32_t>> laneBaseExt;
    if (isMfma32) {
      laneBaseExt = {{16, 0}, {0, numSubtilesPerTile * threadsPerSubtileK}};
    } else {
      laneBaseExt = {{0, numSubtilesPerTile * threadsPerSubtileK},
                     {0, 2 * numSubtilesPerTile * threadsPerSubtileK}};
    }
    laneBase.insert(laneBase.end(), laneBaseExt.begin(), laneBaseExt.end());
  };

  if (isFP4)
    populateFP4LL(kSize, mDim);
  else
    populateLL(elemBitWidth, kSize, kWidthDot, mDim);

  LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}},
                          {outDimNames[order[0]], outDimNames[order[1]]});
  if (hasBatchDim) {
    // Registers and lanes do not span the batch dim.
    assert(order[2] == 0);
    tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]);
    tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
  }

  auto warpOrder = getDefaultMmaOrder(mfmaLayout);
  LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder);
  LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
                           warpLayout.transposeOuts(outDimNames);
  return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape);
}
// Linear layout of a dot operand whose parent is an AMDMfmaEncodingAttr.
//
// The per-warp tile is built on dimK = outDimNames[order[0]] and
// dimNonK = outDimNames[order[1]], where `order` is K-contiguous
// (getOrderForDotOperand(..., true)):
//   - registers: kWidth consecutive elements along K;
//   - lanes: nonKDim along the non-K dim, then the rest of the 64 lanes
//     along K;
//   - registers: 16 more along K for the 64x4 / 4x64 special cases;
//   - registers: further K repetitions when kSize exceeds one K tile;
//   - registers: tilePerWarpNonK repetitions along the non-K dim.
// Warps use identityStandardND over all dims in the default MMA order. Unlike
// wmmaDotOperandToLinearLayout / nvidiaDotToLinearLayout, they are not
// broadcast along K.
// combineCtaCgaWithShape extends the layout while the out-dims are still in
// [K, non-K(, batch)] order; only then is the result transposed to standard
// order.
//
// NOTE(review): tilePerWarpNonK is read from tilesPerWarp[kDimIndex].
// tilesPerWarp belongs to the parent MFMA layout, which is indexed [M, N].
// For operand A, kDimIndex is rank - 1 (the parent's N entry), yet the value
// repeats the tile along A's non-K dim (M); operand B is the mirror case.
// This only agrees with the parent layout when tilesPerWarp[M] ==
// tilesPerWarp[N]; confirm whether the non-K index was intended.
// Also, mIndex is computed but not used in this function.
LinearLayout mfmaDotToLinearLayout(DotOperandEncodingAttr dotMfmaLayout,
ArrayRef<int64_t> shape) {
auto mfmaLayout = llvm::cast<AMDMfmaEncodingAttr>(dotMfmaLayout.getParent());
auto rank = shape.size();
bool hasBatchDim = rank == 3;
int mIndex = 0 + hasBatchDim;
int32_t kWidth = dotMfmaLayout.getKWidth();
auto kDimIndex = dotMfmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
auto tilesPerWarp = mfmaLayout.getTilesPerWarp();
auto tilePerWarpNonK = tilesPerWarp[kDimIndex];
auto mDim = mfmaLayout.getMDim();
auto nDim = mfmaLayout.getNDim();
auto opIdx = dotMfmaLayout.getOpIdx();
// Instruction extent along this operand's non-K dim.
auto nonKDim = opIdx == 0 ? mDim : nDim;
constexpr int warpSize = 64;
int32_t kSize = shape[kDimIndex];
MLIRContext *ctx = dotMfmaLayout.getContext();
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
StringAttr kRegister = S("register");
StringAttr kLane = S("lane");
StringAttr kWarp = S("warp");
auto order =
getOrderForDotOperand(dotMfmaLayout.getOpIdx(), rank, true);
auto dimK = outDimNames[order[0]];
auto dimNonK = outDimNames[order[1]];
auto warpOrder = getDefaultMmaOrder(mfmaLayout);
LinearLayout regs = LinearLayout::identity1D(kWidth, kRegister, dimK);
LinearLayout lanes =
LinearLayout::identity1D(nonKDim, kLane, dimNonK) *
LinearLayout::identity1D(warpSize / nonKDim, kLane, dimK);
LinearLayout tileLayout = regs * lanes;
// K extent covered by one tile.
int kTileSize = warpSize / nonKDim * kWidth;
// 64x4 (operand A) / 4x64 (operand B): repeat 16 more times along K.
if ((mDim == 64 && nDim == 4 && opIdx == 0) ||
(mDim == 4 && nDim == 64 && opIdx == 1)) {
tileLayout *= LinearLayout::identity1D(16, kRegister, dimK);
kTileSize *= 16;
}
// Cover the rest of K with register repetitions.
if (kSize > kTileSize) {
tileLayout *= LinearLayout::identity1D(kSize / kTileSize, kRegister, dimK);
}
tileLayout *= LinearLayout::identity1D(tilePerWarpNonK, kRegister, dimNonK);
tileLayout = tileLayout.transposeOuts({dimK, dimNonK});
if (hasBatchDim) {
// Registers and lanes do not span the batch dim.
assert(order[2] == 0);
tileLayout *= LinearLayout::identity1D(1, kRegister, outDimNames[order[2]]);
tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[order[2]]);
}
LinearLayout warpLayout = identityStandardND(kWarp, warpsPerCTA, warpOrder);
LinearLayout ctaLayout = tileLayout * warpLayout;
return combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape)
.transposeOuts(outDimNames);
}
// Linear layout of the WMMA accumulator layout. It maps the register, lane
// and warp in-dims to tensor coordinates.
//
// Rank 2 is [M, N]. Rank 3 adds a leading batch dim. The instruction tile uses
// explicit bases over (outDimNames[threadOrder[0]],
// outDimNames[threadOrder[1]]), with
// threadOrder = getMatrixOrder(rank, !getIsTransposed()).
//
// The two versions differ in where registers and lane 16 go:
//   v1: registers step 2, 4, 8 along the second coordinate; lane 16 steps 1.
//   v2: registers step 1, 2, 4 along the second coordinate; lane 16 steps 8.
// In both versions, lanes 0-15 step along the first coordinate.
//
// Warps use the default MMA order. The tile and warp layouts are both
// reordered by getRepOrder() before being extended to `shape`.
LinearLayout
AMDWmmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
int rank = shape.size();
assert(rank == getRank());
bool hasBatchDim = rank == 3;
int mIndex = 0 + hasBatchDim;
int nIndex = 1 + hasBatchDim;
// mIndex/nIndex and mDim/nDim are only read inside asserts; the (void) casts
// silence unused-variable warnings when asserts are compiled out.
(void)mIndex, (void)nIndex;
SmallVector<unsigned> mnkDim = getMNKDimPerInstr();
unsigned mDim = mnkDim[0], nDim = mnkDim[1];
(void)mDim, (void)nDim;
assert(((shape[mIndex] == 1 || shape[mIndex] >= mDim) &&
(shape[nIndex] == 1 || shape[nIndex] >= nDim)) &&
"Unsupported tensor shape for given wmma layout");
MLIRContext *ctx = getContext();
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
StringAttr kRegister = S("register");
StringAttr kLane = S("lane");
auto threadOrder = getMatrixOrder(rank, !getIsTransposed());
assert(threadOrder[0] == mIndex || threadOrder[0] == nIndex);
assert(threadOrder[1] == mIndex || threadOrder[1] == nIndex);
unsigned ver = getVersion();
assert(ver == 1 || ver == 2);
// One instruction tile; the fifth lane basis corresponds to lane 16.
LinearLayout tileLayout =
ver == 1
? LinearLayout(
{{kRegister, { {0, 2}, {0, 4}, {0, 8}}},
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 1}}}},
{outDimNames[threadOrder[0]], outDimNames[threadOrder[1]]})
: LinearLayout(
{{kRegister, {{0, 1}, {0, 2}, {0, 4}}},
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 8}}}},
{outDimNames[threadOrder[0]], outDimNames[threadOrder[1]]});
if (hasBatchDim) {
// Registers and lanes do not span the batch dim (dim 0).
int batchIndex = 0;
tileLayout *=
LinearLayout::identity1D(1, kRegister, outDimNames[batchIndex]);
tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[batchIndex]);
}
auto warpOrder = getDefaultMmaOrder(*this);
LinearLayout warpLayout =
identityStandardND(S("warp"), getWarpsPerCTA(), warpOrder);
// Put both pieces in repetition order so combineCtaCgaWithShape extends the
// dims in that order.
auto repOrder = getRepOrder();
SmallVector<StringAttr> repDimNames;
for (auto dim : repOrder)
repDimNames.push_back(outDimNames[dim]);
LinearLayout ctaLayout = tileLayout.transposeOuts(repDimNames) *
warpLayout.transposeOuts(repDimNames);
return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
}
// Linear layout of a dot operand whose parent is an AMDWmmaEncodingAttr.
//
// One WMMA instruction tile uses explicit bases over
// (outDimNames[laneOrder[0]], outDimNames[laneOrder[1]]). Here laneOrder is
// getOrderForDotOperand(opIdx, rank, true), so the first coordinate is the
// one kWidth steps along (presumably K — confirm with getOrderForDotOperand).
//   - registers: kWidth consecutive elements along the first coordinate;
//   - lanes 0-15: step along the second coordinate;
//   - lane 16: version 1 maps it to the same elements as lane 0 (basis 0);
//     version 2 offsets it by kWidth along the first coordinate.
//
// Warps are broadcast along the K dim (broadcastedDotOperandLayout). The tile
// and warp layouts are both put in the parent's repetition order for this
// operand before being extended to `shape`.
LinearLayout wmmaDotOperandToLinearLayout(DotOperandEncodingAttr dotWmmaLayout,
                                          ArrayRef<int64_t> shape) {
  auto wmmaLayout = llvm::cast<AMDWmmaEncodingAttr>(dotWmmaLayout.getParent());
  auto rank = shape.size();
  bool hasBatchDim = rank == 3;
  auto kDim = dotWmmaLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;

  MLIRContext *ctx = dotWmmaLayout.getContext();
  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
  StringAttr kRegister = S("register");
  StringAttr kLane = S("lane");
  StringAttr kWarp = S("warp");

  auto laneOrder =
      getOrderForDotOperand(dotWmmaLayout.getOpIdx(), rank, true);

  std::vector<std::vector<int32_t>> registerBase;
  const int32_t kWidth = dotWmmaLayout.getKWidth();
  for (int i = 1; i < kWidth; i *= 2)
    registerBase.push_back(std::vector<int32_t>{i, 0});

  std::vector<std::vector<int32_t>> laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}};
  switch (wmmaLayout.getVersion()) {
  case 1:
    // Lanes 16-31 repeat the elements of lanes 0-15.
    laneBase.push_back({0, 0});
    break;
  case 2:
    // Lanes 16-31 are offset by kWidth along the first coordinate.
    laneBase.push_back({kWidth, 0});
    break;
  default:
    assert(false && "unexpected version");
  }

  LinearLayout tileLayout(
      {{kRegister, registerBase}, {kLane, laneBase}},
      {outDimNames[laneOrder[0]], outDimNames[laneOrder[1]]});
  if (hasBatchDim) {
    // Registers and lanes do not span the batch dim.
    assert(laneOrder[2] == 0);
    tileLayout *=
        LinearLayout::identity1D(1, kRegister, outDimNames[laneOrder[2]]);
    tileLayout *= LinearLayout::identity1D(1, kLane, outDimNames[laneOrder[2]]);
  }

  auto warpsPerCTA = wmmaLayout.getWarpsPerCTA();
  auto warpOrder = getDefaultMmaOrder(wmmaLayout);
  LinearLayout warpLayout =
      broadcastedDotOperandLayout(ctx, warpsPerCTA, warpOrder, kDim, kWarp);

  auto repOrder = wmmaLayout.getRepOrderForOperand(dotWmmaLayout.getOpIdx());
  SmallVector<StringAttr> repDimNames;
  for (auto dim : repOrder)
    repDimNames.push_back(outDimNames[dim]);

  LinearLayout ctaLayout = tileLayout.transposeOuts(repDimNames) *
                           warpLayout.transposeOuts(repDimNames);
  return combineCtaCgaWithShape(ctaLayout, wmmaLayout.getCTALayout(), shape);
}
// Linear layout of a blocked encoding. Registers (sizePerThread), lanes
// (threadsPerWarp) and warps (warpsPerCTA) are each laid out as an N-D
// identity in the encoding's order. Their product is the per-CTA layout, which
// is then extended to `shape` with the CTA layout.
LinearLayout
BlockedEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
  MLIRContext *ctx = getContext();
  auto order = getOrder();
  LinearLayout regs =
      identityStandardND(S("register"), getSizePerThread(), order);
  LinearLayout lanes =
      identityStandardND(S("lane"), getThreadsPerWarp(), order);
  LinearLayout warps = identityStandardND(S("warp"), getWarpsPerCTA(), order);
  return combineCtaCgaWithShape(regs * lanes * warps, getCTALayout(), shape);
}
// Linear layout of a dot operand whose parent is a BlockedEncodingAttr (the
// FMA dot path).
//
// - Registers, lanes and warps all reuse the parent's order.
// - Registers cover sizePerThread, except along the K dim, where each thread
//   holds the full K extent of `shape`.
// - Lanes and warps are broadcast along K (broadcastedDotOperandLayout).
// - The three pieces are put in the parent's repetition order, multiplied,
//   and extended to `shape`.
LinearLayout fmaDotToLinearLayout(DotOperandEncodingAttr operandLayout,
                                  ArrayRef<int64_t> shape) {
  int rank = shape.size();
  auto blocked = cast<BlockedEncodingAttr>(operandLayout.getParent());
  MLIRContext *ctx = operandLayout.getContext();

  auto order = blocked.getOrder();
  auto repOrder = blocked.getRepOrder();
  SmallVector<StringAttr> repDimNames =
      permuteDimNames(standardOutDimNames(ctx, rank), repOrder);

  auto kDimIdx = operandLayout.getOpIdx() == 0 ? rank - 1 : rank - 2;
  auto sizePerThread = llvm::to_vector(blocked.getSizePerThread());
  sizePerThread[kDimIdx] = shape[kDimIdx];

  LinearLayout regs =
      identityStandardND(S("register"), sizePerThread, order)
          .transposeOuts(repDimNames);
  LinearLayout lanes =
      broadcastedDotOperandLayout(ctx, blocked.getThreadsPerWarp(), order,
                                  kDimIdx, S("lane"))
          .transposeOuts(repDimNames);
  LinearLayout warps =
      broadcastedDotOperandLayout(ctx, blocked.getWarpsPerCTA(), order,
                                  kDimIdx, S("warp"))
          .transposeOuts(repDimNames);

  return combineCtaCgaWithShape(regs * lanes * warps,
                                getCTALayout(operandLayout), shape);
}
// Register/lane layout of one NVIDIA MMA tile.
//
// With inner = order[0], outer = order[1], m = tileShape[outer] and
// n = tileShape[inner], the distribution is, in order:
//   kWidth registers along inner,
//   4 lanes along inner,
//   8 lanes along outer,
//   m / 8 registers along outer,
//   n / (kWidth * 4) registers along inner.
// m must be a multiple of 8 and n a multiple of kWidth * 4 (asserted).
//
// The product starts from a size-1 layout over all `rank` out-dims, built with
// `repOrder`. It contributes no bases; it only fixes the out-dim order of the
// result to repOrder.
LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
                           unsigned kWidth, ArrayRef<unsigned> order,
                           ArrayRef<unsigned> repOrder) {
  int rank = repOrder.size();
  auto dimNames = standardOutDimNames(ctx, rank);
  auto trivialShape = SmallVector<unsigned>(rank, 1);
  LinearLayout ctaLayout =
      identityStandardND(S("register"), trivialShape, repOrder);

  assert(rank >= 2);
  auto inner = order[0];
  auto outer = order[1];

  assert(tileShape.size() == rank);
  int m = tileShape[outer];
  int n = tileShape[inner];
  assert(m % 8 == 0);
  assert(n % (kWidth * 4) == 0);

  ctaLayout = ctaLayout *
              LinearLayout::identity1D(kWidth, S("register"), dimNames[inner]) *
              LinearLayout::identity1D(4, S("lane"), dimNames[inner]) *
              LinearLayout::identity1D(8, S("lane"), dimNames[outer]) *
              LinearLayout::identity1D(m / 8, S("register"), dimNames[outer]) *
              LinearLayout::identity1D(n / (kWidth * 4), S("register"),
                                       dimNames[inner]);
  return ctaLayout;
}
// Linear layout of the NVIDIA MMA accumulator layout.
//
// - The tile shape is the instruction shape on Ampere; on Hopper it is the
//   M and N extents of the MNK instruction shape.
// - The tile is built with nvidiaMmaTile using a fixed kWidth of 2.
// - Warps use row-major order except on Hopper.
// - The result is extended to `shape` with the CTA layout.
LinearLayout
NvidiaMmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
  auto ctx = getContext();
  int rank = shape.size();
  assert(rank == getRank());

  auto instrShape = getInstrShape();
  SmallVector<unsigned> tileShape;
  if (isAmpere()) {
    tileShape = SmallVector<unsigned>(instrShape);
  } else {
    assert(isHopper());
    tileShape = SmallVector<unsigned>({instrShape[0], instrShape[1]});
  }

  constexpr auto kWidth = 2;
  LinearLayout ctaLayout = nvidiaMmaTile(
      ctx, tileShape, kWidth, getDefaultMmaOrder(*this), getRepOrder());

  auto tileOutDims = llvm::to_vector(ctaLayout.getOutDimNames());
  auto warpOrder = getMatrixOrder(rank, !isHopper());
  LinearLayout warps =
      identityStandardND(S("warp"), getWarpsPerCTA(), warpOrder);
  ctaLayout *= warps.transposeOuts(tileOutDims);
  return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape);
}
// Linear layout of a dot operand whose parent is an NvidiaMmaEncodingAttr.
//
// Tile shape of the two innermost dims (leading dims are 1):
//   operand A: [16, kWidth * 8];
//   operand B: [kWidth * 8, 8] (Ampere only; asserted).
//
// The tile is built with nvidiaMmaTile in the K-contiguous operand order.
// Warps use the parent's warpsPerCTA and are broadcast along K. The result is
// extended to `shape` with the CTA layout.
LinearLayout nvidiaDotToLinearLayout(ArrayRef<int64_t> shape,
                                     DotOperandEncodingAttr dot) {
  const int rank = shape.size();
  auto mma = cast<NvidiaMmaEncodingAttr>(dot.getParent());
  const int kWidth = dot.getKWidth();
  const bool isA = dot.getOpIdx() == 0;
  MLIRContext *ctx = mma.getContext();

  SmallVector<unsigned> tileShape(rank, 1);
  unsigned &rows = tileShape[rank - 2];
  unsigned &cols = tileShape[rank - 1];
  if (isA) {
    rows = 16;
    cols = kWidth * 8;
  } else {
    assert(mma.isAmpere());
    rows = kWidth * 8;
    cols = 8;
  }

  auto order = getOrderForDotOperand(dot.getOpIdx(), rank, true);
  LinearLayout ctaLayout =
      nvidiaMmaTile(ctx, tileShape, kWidth, order, dot.getRepOrder());

  auto tileOutDims = llvm::to_vector(ctaLayout.getOutDimNames());
  const auto kDim = isA ? rank - 1 : rank - 2;
  auto warpOrder = getMatrixOrder(rank, !mma.isHopper());
  LinearLayout warps = broadcastedDotOperandLayout(
      ctx, mma.getWarpsPerCTA(), warpOrder, kDim, S("warp"));
  ctaLayout *= warps.transposeOuts(tileOutDims);
  return combineCtaCgaWithShape(ctaLayout, getCTALayout(dot), shape);
}
// Linear layout of a dot operand encoding. Dispatches on the parent layout:
//   BlockedEncodingAttr -> fmaDotToLinearLayout
//   AMDMfmaEncodingAttr -> mfmaDotToLinearLayout
//   AMDWmmaEncodingAttr -> wmmaDotOperandToLinearLayout
//   anything else       -> nvidiaDotToLinearLayout, whose
//                          cast<NvidiaMmaEncodingAttr> asserts on other kinds.
// Only the parent's kind matters here, so isa<> is used instead of dyn_cast
// bindings that were never read.
LinearLayout
DotOperandEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
  auto parent = getParent();
  if (mlir::isa<BlockedEncodingAttr>(parent))
    return fmaDotToLinearLayout(*this, shape);
  if (mlir::isa<AMDMfmaEncodingAttr>(parent))
    return mfmaDotToLinearLayout(*this, shape);
  if (mlir::isa<AMDWmmaEncodingAttr>(parent))
    return wmmaDotOperandToLinearLayout(*this, shape);
  return nvidiaDotToLinearLayout(shape, *this);
}
// Linear layout of a slice encoding.
//
// 1. Compute the parent layout on `shape` with a size-1 dim re-inserted at
//    getDim().
// 2. Compose it with a layout that maps the sliced parent out-dim to zero and
//    renumbers the remaining out-dims down by one.
// 3. Drop the register bases that became all-zero after the composition.
LinearLayout SliceEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
  MLIRContext *ctx = getContext();

  SmallVector<int64_t> parentShape(shape);
  parentShape.insert(parentShape.begin() + getDim(), 1);
  LinearLayout parentLL = triton::gpu::toLinearLayout(parentShape, getParent());

  auto outDimNames = standardOutDimNames(ctx, shape.size() + 1);
  LinearLayout dropDim = LinearLayout::empty();
  for (auto [idx, outDim] : llvm::enumerate(parentLL.getOutDimNames())) {
    auto size = parentLL.getOutDimSize(outDim);
    if (idx == getDim()) {
      // Any valid out-dim works as the target here, since every basis is zero.
      dropDim *= LinearLayout::zeros1D(size, outDim, outDimNames[0]);
      continue;
    }
    auto target = idx - (idx < getDim() ? 0 : 1);
    dropDim *= LinearLayout::identity1D(size, outDim, outDimNames[target]);
  }
  LinearLayout sliceLL = parentLL.compose(dropDim);

  auto bases = sliceLL.getBases();
  auto &registerBases = bases[S("register")];
  llvm::erase_if(registerBases, [](const auto &basis) {
    return llvm::all_of(basis, [](int b) { return b == 0; });
  });
  return LinearLayout(std::move(bases),
                      llvm::to_vector(sliceLL.getOutDimNames()));
}
// Linear layout of a tensor-memory encoding. It maps the ("row", "col")
// in-dims to (dim0, dim1) tensor coordinates. Only rank 2 is supported
// (asserted).
//
// CTA splits are peeled off recursively, N split first. Each split is
// appended as an identity on the "col" in-dim, over the split tensor dim,
// after the layout of the per-split shape. The result is therefore
// (base * splitM) * splitN.
//
// Base case (no splits), where blockM must be 64 or 128 (asserted):
//   blockM == 128: rows map to dim0 and cols to dim1 (identity).
//   blockM == 64:  rows 0-15 map to dim0. The basis for row 16 selects:
//                    the second 64-row block if shape[0] > blockM;
//                    otherwise the second blockN-column block if
//                    shape[1] > blockN;
//                    otherwise nothing (basis {0, 0}).
//                  The bases for rows 32 and 64 map to dim0 16 and 32.
// Any remaining repetitions, along dim0 and then dim1, are placed on the
// "col" in-dim.
LinearLayout tensorMemoryToLinearLayout(ArrayRef<int64_t> shape,
TensorMemoryEncodingAttr encoding) {
assert(shape.size() == 2);
auto *ctx = encoding.getContext();
auto kRow = S("row");
auto kCol = S("col");
auto dims = standardOutDimNames(ctx, 2);
if (encoding.getCTASplitN() > 1) {
// Peel the N split: lay out shape[1] / splitN columns, then repeat splitN
// times along dim1 on the "col" in-dim.
auto split =
LinearLayout::identity1D(encoding.getCTASplitN(), kCol, dims[1]);
auto newEncoding = TensorMemoryEncodingAttr::get(
ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(),
encoding.getCTASplitM(), 1);
return tensorMemoryToLinearLayout(
{shape[0], shape[1] / encoding.getCTASplitN()}, newEncoding) *
split;
}
if (encoding.getCTASplitM() > 1) {
// Peel the M split: lay out shape[0] / splitM rows, then repeat splitM
// times along dim0 on the "col" in-dim.
auto split =
LinearLayout::identity1D(encoding.getCTASplitM(), kCol, dims[0]);
auto newEncoding = TensorMemoryEncodingAttr::get(
ctx, encoding.getBlockM(), encoding.getBlockN(), encoding.getUnpacked(),
1, encoding.getCTASplitN());
return tensorMemoryToLinearLayout(
{shape[0] / encoding.getCTASplitM(), shape[1]}, newEncoding) *
split;
}
assert(encoding.getCTASplitM() == 1 && encoding.getCTASplitN() == 1);
auto blockM = encoding.getBlockM();
auto blockN = encoding.getBlockN();
assert(blockM == 64 || blockM == 128);
LinearLayout tile;
if (blockM == 64) {
// 16 rows x blockN columns, then three extra row bases (rows 16, 32, 64).
tile = LinearLayout::identity1D(16, kRow, dims[0]) *
LinearLayout::identity1D(blockN, kCol, dims[1]);
auto bases = tile.getBases();
if (shape[0] > blockM) {
bases[kRow].push_back({64, 0});
} else if (shape[1] > blockN) {
bases[kRow].push_back({0, static_cast<int32_t>(blockN)});
} else {
// No second block to address: this row bit maps to the same elements.
bases[kRow].push_back({0, 0});
}
bases[kRow].push_back({16, 0});
bases[kRow].push_back({32, 0});
tile = LinearLayout(bases, dims);
} else {
tile = LinearLayout::identity1D(blockM, kRow, dims[0]) *
LinearLayout::identity1D(blockN, kCol, dims[1]);
}
// Repeat the tile along dim0, then dim1, on the "col" in-dim.
auto repsM = shape[0] / tile.getOutDimSize(dims[0]);
auto repsN = shape[1] / tile.getOutDimSize(dims[1]);
assert(repsM >= 1 && repsN >= 1);
tile = tile * LinearLayout::identity1D(repsM, kCol, dims[0]) *
LinearLayout::identity1D(repsN, kCol, dims[1]);
return tile;
}
// Linear layout of a tensor-memory scales encoding (rank 2 only; asserted).
//
// A tile of up to 32 rows is placed on the "row" in-dim, and up to 4 columns
// on the "col" in-dim. Repetitions of that tile, along dim0 and then dim1, are
// also placed on the "col" in-dim.
LinearLayout
tensorMemoryScalesToLinearLayout(ArrayRef<int64_t> shape,
                                 TensorMemoryScalesEncodingAttr encoding) {
  assert(shape.size() == 2);
  auto *ctx = encoding.getContext();
  auto kRow = S("row");
  auto kCol = S("col");
  auto dims = standardOutDimNames(ctx, 2);

  const int64_t rows = shape[0];
  const int64_t cols = shape[1];
  LinearLayout tileRows =
      LinearLayout::identity1D(std::min<int>(32, rows), kRow, dims[0]);
  LinearLayout tileCols =
      LinearLayout::identity1D(std::min<int>(4, cols), kCol, dims[1]);
  LinearLayout repsRows =
      LinearLayout::identity1D(std::max<int>(1, rows / 32), kCol, dims[0]);
  LinearLayout repsCols =
      LinearLayout::identity1D(std::max<int>(1, cols / 4), kCol, dims[1]);
  return tileRows * tileCols * repsRows * repsCols;
}
// Converts `layout` applied to a tensor of the given `shape` into a
// LinearLayout. Results are memoized in llCache keyed on (shape, layout).
//
// Distributed encodings convert themselves through DistributedEncodingTrait.
// The remaining encodings handled here (swizzled/NVMMA/AMD-rotating shared,
// tensor-memory and tensor-memory-scales) require every dimension of `shape`
// to be a positive power of two. An unrecognized encoding is a fatal error.
LinearLayout TritonGPUDialect::toLinearLayout(ArrayRef<int64_t> shape,
                                              Attribute layout) {
  CacheKey key{std::vector<int64_t>(shape.begin(), shape.end()), layout};
  if (auto result = llCache.get(key)) {
    return *result;
  }

  LinearLayout result = LinearLayout::empty();
  if (auto distributed = dyn_cast<DistributedEncodingTrait>(layout)) {
    result = distributed.toLinearLayout(shape);
  } else {
    // Use the 64-bit predicate: isPowerOf2_32 would truncate int64_t dims
    // that do not fit in 32 bits and could accept/reject the wrong value.
    assert(llvm::all_of(shape,
                        [](int64_t dim) {
                          return dim >= 1 && llvm::isPowerOf2_64(dim);
                        }) &&
           "shape must be a positive power of 2");
    if (auto shared = dyn_cast<SwizzledSharedEncodingAttr>(layout)) {
      result = swizzledSharedToLinearLayout(shape, shared);
    } else if (auto shared = dyn_cast<NVMMASharedEncodingAttr>(layout)) {
      result = nvmmaSharedToLinearLayout(shape, shared);
    } else if (auto sbl = dyn_cast<AMDRotatingSharedEncodingAttr>(layout)) {
      result = sharedToLinearLayoutAMDRotating(shape, sbl);
    } else if (auto tensorMemoryEncoding =
                   dyn_cast<TensorMemoryEncodingAttr>(layout)) {
      result = tensorMemoryToLinearLayout(shape, tensorMemoryEncoding);
    } else if (auto tensorMemoryScalesEncoding =
                   dyn_cast<TensorMemoryScalesEncodingAttr>(layout)) {
      result =
          tensorMemoryScalesToLinearLayout(shape, tensorMemoryScalesEncoding);
    } else {
      // Fail loudly in every build mode rather than caching and returning an
      // empty layout that would silently corrupt downstream lowering.
      llvm::report_fatal_error("unknown layout");
    }
  }
  llCache.set(std::move(key), result);
  return result;
}
// Linear layout of a ranked tensor, computed from its shape and encoding.
LinearLayout toLinearLayout(RankedTensorType type) {
  ArrayRef<int64_t> tensorShape = type.getShape();
  Attribute encoding = type.getEncoding();
  return toLinearLayout(tensorShape, encoding);
}
// Linear layout of a memory descriptor. Only the last getRank() entries of
// the allocation shape are used; any extra leading allocation dims are
// ignored.
LinearLayout toLinearLayout(MemDescType type) {
  ArrayRef<int64_t> allocShape = type.getAllocShape();
  return toLinearLayout(allocShape.take_back(type.getRank()),
                        type.getEncoding());
}
// Dispatches to the RankedTensorType or MemDescType overload.
LinearLayout toLinearLayout(TensorOrMemDesc type) {
  if (auto tensorTy = dyn_cast<RankedTensorType>(type))
    return toLinearLayout(tensorTy);
  return toLinearLayout(cast<MemDescType>(type));
}
// Forwards to the (caching) TritonGPUDialect::toLinearLayout of the dialect
// loaded in the layout's context.
LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout) {
  MLIRContext *context = layout.getContext();
  TritonGPUDialect *dialect = context->getLoadedDialect<TritonGPUDialect>();
  return dialect->toLinearLayout(shape, layout);
}
// Returns a copy of `layout` with every basis of the "block" input dim
// removed, keeping the original output dim names. `layout` must have at
// least one input dim, and one of them must be "block".
LinearLayout getLayoutWithinBlock(const LinearLayout &layout) {
  auto inDims = layout.getInDimNames();
  assert(!inDims.empty());
  MLIRContext *ctx = inDims.begin()->getContext();
  StringAttr blockDim = S("block");
  assert(layout.hasInDim(blockDim));

  auto withoutBlock = layout.getBases();
  withoutBlock[blockDim].clear();
  return LinearLayout(std::move(withoutBlock),
                      llvm::to_vector<4>(layout.getOutDimNames()));
}
// Builds the scratch layout used for a register-to-register layout
// conversion.
//
// Each tensor dim d is visited in `order`. For each one, the layout gets an
// input dim "offset<d>" of size repShape[d], then an input dim
// "iteration<d>" of size tensorShape[d] / repShape[d], both mapping to output
// dim d. The per-dim inputs are then regrouped (all offsets first, then all
// iterations) and flattened into three input dims:
// - "offset": product of repShape
// - "iteration": total repetition count
// - "block": trivial, size 1
LinearLayout chooseShemLayoutForRegToRegConversion(
    MLIRContext *ctx, ArrayRef<unsigned> tensorShape,
    ArrayRef<unsigned> repShape, ArrayRef<unsigned> order) {
  const size_t rank = tensorShape.size();
  auto outDimNames = standardOutDimNames(ctx, rank);
  LinearLayout perDimLayout = LinearLayout::empty();
  SmallVector<StringAttr> offsetDims;
  SmallVector<StringAttr> iterationDims;
  int numOffsets = 1;
  int numIterations = 1;

  for (size_t i = 0; i < rank; ++i) {
    const unsigned dim = order[i];
    const std::string suffix = std::to_string(dim);
    StringAttr iterationDim = S("iteration" + suffix);
    StringAttr offsetDim = S("offset" + suffix);
    iterationDims.push_back(iterationDim);
    offsetDims.push_back(offsetDim);

    assert(llvm::isPowerOf2_32(repShape[dim]));
    assert(llvm::isPowerOf2_32(tensorShape[dim]));
    const unsigned repeats = tensorShape[dim] / repShape[dim];

    // The offset bases of this dim precede its iteration bases.
    perDimLayout *=
        LinearLayout::identity1D(repShape[dim], offsetDim, outDimNames[dim]);
    perDimLayout *=
        LinearLayout::identity1D(repeats, iterationDim, outDimNames[dim]);

    numIterations *= repeats;
    numOffsets *= repShape[dim];
  }

  SmallVector<StringAttr> inDimOrder(offsetDims.begin(), offsetDims.end());
  inDimOrder.append(iterationDims.begin(), iterationDims.end());
  LinearLayout regrouped = perDimLayout.transposeIns(inDimOrder);
  return regrouped.reshapeIns({{S("offset"), numOffsets},
                               {S("iteration"), numIterations},
                               {S("block"), 1}});
}
// Thin adapter: `enc` must be a DotOperandEncodingAttr; the actual layout
// choice is delegated to chooseDotDsReadB64TrLayout.
LinearLayout chooseDsReadB64TrLayout(Attribute enc, ArrayRef<int64_t> shape,
                                     int32_t elemBitWidth) {
  return chooseDotDsReadB64TrLayout(cast<DotOperandEncodingAttr>(enc), shape,
                                    elemBitWidth);
}
// Builds the register/lane/warp/block layout for the scale tensor of a
// scaled MFMA dot operand.
//
// dotOperandIdx:   which dot operand (0 or 1) the scales belong to; operand 1
//                  uses a transposed warp grid.
// dotOperandShape: 2-D scale tensor shape; dotOperandShape[1] is the K extent.
// mfmaMDim:        16 or 32; selects how many lanes cover K (4 or 2).
// tilesPerWarp:    tilesPerWarp[mnDim] mfma tiles along M/N are held in
//                  registers.
// warpsPerCTA:     2-D warp grid.
//
// Basis vectors below are written in (order[0]-dim, order[1]-dim)
// coordinates; the first coordinate walks K (bounded by kSize).
LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
                                         ArrayRef<int64_t> dotOperandShape,
                                         unsigned mfmaMDim,
                                         ArrayRef<unsigned> tilesPerWarp,
                                         ArrayRef<unsigned> warpsPerCTA) {
  using basisT = std::vector<std::vector<int32_t>>;
  unsigned rank = dotOperandShape.size();
  // The tile, warp and CTA layouts built below are all exactly 2-D.
  assert(rank == 2 && "expected a 2-D scale tensor");
  auto order = mlir::triton::gpu::getMatrixOrder(rank, true);
  StringAttr kRegister = StringAttr::get(ctx, "register");
  StringAttr kLane = StringAttr::get(ctx, "lane");
  StringAttr kWarp = StringAttr::get(ctx, "warp");

  unsigned mnDim = dotOperandIdx == 0 ? rank - 2 : rank - 1;
  unsigned tilePerWarpMN = tilesPerWarp[mnDim];
  int32_t kSize = dotOperandShape[1];

  // Lanes cover the first threadsInKDim K positions; the rest of K lives in
  // registers.
  basisT registerBase;
  int32_t threadsInKDim = mfmaMDim == 32 ? 2 : 4;
  for (int32_t elem = threadsInKDim; elem < kSize; elem *= 2)
    registerBase.push_back({elem, 0});
  // Additional mfma tiles per warp along M/N are also held in registers.
  const int32_t mnExtent = static_cast<int32_t>(tilePerWarpMN * mfmaMDim);
  for (int32_t elem = static_cast<int32_t>(mfmaMDim); elem < mnExtent;
       elem *= 2)
    registerBase.push_back({0, elem});

  basisT laneBase;
  if (mfmaMDim == 32) {
    laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {1, 0}};
  } else {
    assert(mfmaMDim == 16);
    laneBase = {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}};
  }

  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
  LinearLayout tileLayout({{kRegister, registerBase}, {kLane, laneBase}},
                          {outDimNames[order[0]], outDimNames[order[1]]});

  // Operand 1 swaps the warp grid and its order.
  SmallVector<unsigned> warpsPerCTANew =
      (dotOperandIdx == 1)
          ? SmallVector<unsigned>{warpsPerCTA[1], warpsPerCTA[0]}
          : SmallVector<unsigned>{warpsPerCTA[0], warpsPerCTA[1]};
  SmallVector<unsigned> warpOrder = (dotOperandIdx == 1)
                                        ? SmallVector<unsigned>{0, 1}
                                        : SmallVector<unsigned>{1, 0};
  LinearLayout warpLayout =
      identityStandardND(kWarp, warpsPerCTANew, warpOrder);
  LinearLayout ctaLayout = tileLayout.transposeOuts(outDimNames) *
                           warpLayout.transposeOuts(outDimNames);

  // Trivial 1x1 CTA layout.
  auto ctaLay = CTALayoutAttr::get(ctx, {1, 1}, {1, 1}, {1, 0});
  return combineCtaCgaWithShape(ctaLayout, ctaLay, dotOperandShape);
}
// Returns a store-friendly variant of a transposed MFMA layout for 2-D
// f16/bf16 tensors, or std::nullopt when the layout does not qualify.
//
// The result is the MFMA linear layout composed with a permutation of the
// N-dimension basis vectors, so that each thread owns 8 consecutive elements
// along N (enabling 128-bit global stores) instead of 4.
//
// A layout qualifies when it is an AMDMfmaEncodingAttr, version 4, transposed,
// f16/bf16, rank 2, and one of:
//   - 32x32 (mfma32) with an N extent large enough for the basis swap, or
//   - 16x16 (mfma16) with N >= 32 and warpsPerCTA.back() == 1.
std::optional<LinearLayout>
chooseMfmaLikeStoreLayout(RankedTensorType valType) {
  auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(valType.getEncoding());
  if (!mfmaLayout)
    return {};
  // Check the rank first so that valShape.back() below is always valid.
  if (valType.getRank() != 2)
    return {};

  bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32;
  bool isMfma16 = mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16;

  auto valShape = valType.getShape();
  bool validForMfma16 = isMfma16 && valShape.back() >= 16 * 2 &&
                        mfmaLayout.getWarpsPerCTA().back() == 1;

  Type elemType = valType.getElementType();
  if (!((elemType.isF16() || elemType.isBF16()) &&
        mfmaLayout.getVersion() == 4 && mfmaLayout.getIsTransposed() &&
        (isMfma32 || validForMfma16)))
    return {};

  LinearLayout mfmaLL = mfmaLayout.toLinearLayout(valShape);
  auto mfmaOutDims = llvm::to_vector(mfmaLL.getOutDimNames());
  StringAttr dimM = mfmaOutDims[0];
  StringAttr dimN = mfmaOutDims[1];

  // N basis index swapped with basis index 2 (column 4): index 3 (column 8)
  // for mfma32, index 4 (column 16) for mfma16.
  int destIdxInBases = isMfma32 ? 3 : 4;
  int numNBases = mfmaLL.getOutDimSizeLog2(dimN);
  // A tensor too narrow along N lacks the basis vectors the swap needs
  // (e.g. mfma32 with N < 16); bail out instead of indexing out of bounds.
  if (numNBases <= destIdxInBases)
    return {};

  // Rows are kept as is with an identity linear layout.
  auto swapLL = LinearLayout::empty();
  swapLL *= LinearLayout::identity1D(valShape[0], dimM, dimM);
  // clang-format off
  /*
  In the transposed mfma32 layout, each thread holds 4 consecutive values along
  the N dim. We want to exchange columns 4-7 (owned by threads 32-63, BLK0) and
  columns 8-11 (owned by threads 0-31, BLK1) in every group of 16 columns, so
  that each thread holds 8 elements. This means swapping the N basis vectors
  at (0-based) indices 2 and 3 (columns 4 and 8) of an identity linear layout
  on the tensor elements.

  Correspondingly, for the transposed mfma16 layout, the output of the
  transposed mfma16x16 is:

                           N/register
  M/Lane          v0       v1       v2       v3       v4       v5       v6       v7
                -------------------------------------------------------------------------
  row0:  0-15   | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
                -------------------------------------------------------------------------
  row1:  16-31  | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
                -------------------------------------------------------------------------
  row2:  32-47  | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
                -------------------------------------------------------------------------
  row3:  48-63  | tile-0 | tile-0 | tile-0 | tile-0 | tile-1 | tile-1 | tile-1 | tile-1 |
                -------------------------------------------------------------------------

  which means the columns v0 to v3 come from one mfma16x16 output (tile-0)
  and the columns v4 to v7 come from another mfma16x16 output (tile-1).

  The following graph is the same as the one above, except that the tile
  number is replaced with coordinates in the tensor:

                  N/register
           -----------------------------------------------
  M/lane   |(0, 0) ... (0, 3)    | (0, 16) ... (0, 19)   |
           |....                 | sub-tensor-0          |
           |(15, 0) ... (15, 3)  | (15, 16) ... (15, 19) |
           -----------------------------------------------
           |(0, 4) ... (0, 7)    | (0, 20) ... (0, 23)   |
           |sub-tensor-1         | ....                  |
           |(15, 4) ... (15, 7)  | (15, 20) ... (15, 23) |
           -----------------------------------------------
           |(0, 8) ... (0, 11)   | (0, 24) ... (0, 27)   |
           |....                 | sub-tensor-2          |
           |(15, 8) ... (15, 11) | (15, 24) ... (15, 27) |
           -----------------------------------------------
           |(0, 12) ... (0, 15)  | (0, 28) ... (0, 31)   |
           |sub-tensor-3         | ....                  |
           |(15, 12) ... (15, 15)| (15, 28) ... (15, 31) |
           -----------------------------------------------

  The basis vectors for register and lane are:
  Register = {{0, 1}, {0, 2}, {0, 16}}
  Lane = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}}

  With this layout, only 4 x fp16 can be packed into one global store.
  To use 128-bit global stores we need to pack 8 elements, which means the
  layout should look like:

                           N/register
  M/Lane          v0       v1       v2       v3       v4       v5       v6       v7
                -------------------------------------------------------------------------
  row0:  0-15   | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 |
                -------------------------------------------------------------------------
  row1:  16-31  | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 |
                -------------------------------------------------------------------------
  row2:  32-47  | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 | tile-0 |
                -------------------------------------------------------------------------
  row3:  48-63  | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 | tile-1 |
                -------------------------------------------------------------------------

  The following graph is the same as the one above, except that the tile
  number is replaced with coordinates in the tensor:

                  N/register
           -----------------------------------------------
           |(0, 0) ... (0, 3)    | (0, 4) ... (0, 7)     |
           |....                 | sub-tensor-1          |
           |(15, 0) ... (15, 3)  | (15, 4) ... (15, 7)   |
           -----------------------------------------------
           |(0, 16) ... (0, 19)  | (0, 20) ... (0, 23)   |
           |sub-tensor-0         | ....                  |
           |(15, 16) ... (15, 19)| (15, 20) ... (15, 23) |
           -----------------------------------------------
           |(0, 8) ... (0, 11)   | (0, 12) ... (0, 15)   |
           |....                 | sub-tensor-3          |
           |(15, 8) ... (15, 11) | (15, 12) ... (15, 15) |
           -----------------------------------------------
           |(0, 24) ... (0, 27)  | (0, 28) ... (0, 31)   |
           |sub-tensor-2         | ....                  |
           |(15, 24) ... (15, 27)| (15, 28) ... (15, 31) |
           -----------------------------------------------

  which means we need to exchange sub-tensor-0 with sub-tensor-1, and
  sub-tensor-2 with sub-tensor-3.
  The basis vectors for register and lane are then:
  Register = {{0, 1}, {0, 2}, {0, 4}}
  Lane = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 16}, {0, 8}}

  The steps to get this layout are: first, we check that the last dim of
  WarpsPerCTA is 1, so we can use v_permlane16. Then we swap the N basis
  vectors at (0-based) indices 2 and 4 (columns 4 and 16) of an identity
  linear layout and compose the result with the original mfma16 LL.
  */
  // clang-format on
  std::vector<std::vector<int32_t>> dimNBases(numNBases);
  std::generate(dimNBases.begin(), dimNBases.end(),
                [i = 0]() mutable { return std::vector<int32_t>{1 << i++}; });
  std::swap(dimNBases[2], dimNBases[destIdxInBases]);
  swapLL *= LinearLayout({{dimN, dimNBases}}, {dimN});
  return mfmaLL.compose(swapLL);
}
// Register/lane/warp layout used to store a 2-D scale tensor to TMEM.
//
// - Registers: two bits for columns 1 and 2 (a zero basis when the column is
//   past the tensor's N), then row steps from 32 up to M, then column steps
//   from 4 up to N.
// - Lanes: 32 lanes cover rows 0-31.
// - Warps: 4 warps broadcast (zero bases).
// - With 8 warps, the last register basis is moved to a third warp basis.
// The result is combined with the encoding's CTA layout over the full shape.
LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType,
                                           int numWarps) {
  assert(numWarps == 4 || numWarps == 8);
  MLIRContext *ctx = scaleType.getContext();
  using basisT = std::vector<std::vector<int32_t>>;
  StringAttr regDim = StringAttr::get(ctx, "register");
  StringAttr laneDim = StringAttr::get(ctx, "lane");
  StringAttr warpDim = StringAttr::get(ctx, "warp");
  const int64_t numRows = scaleType.getDimSize(0);
  const int64_t numCols = scaleType.getDimSize(1);

  basisT regs;
  for (int32_t col : {1, 2})
    regs.push_back({0, col < numCols ? col : 0});
  for (int32_t row = 32; row < numRows; row *= 2)
    regs.push_back({row, 0});
  for (int32_t col = 4; col < numCols; col *= 2)
    regs.push_back({0, col});

  basisT lanes = {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}};
  basisT warps = {{0, 0}, {0, 0}};
  if (numWarps == 8) {
    warps.push_back(regs.back());
    regs.pop_back();
  }

  SmallVector<StringAttr> outDims = standardOutDimNames(ctx, 2);
  auto layout =
      LinearLayout({{regDim, regs}, {laneDim, lanes}, {warpDim, warps}},
                   {outDims[0], outDims[1]});
  auto ctaLayout = getCTALayout(scaleType.getEncoding());
  return combineCtaCgaWithShape(layout, ctaLayout, scaleType.getShape());
}
// Builds a register/lane/warp layout for TMEM loads/stores based on an
// 8 x (256 bits of elements) nvidia mma tile. The "16x256" presumably refers
// to the 16x256b TMEM access shape; TODO confirm.
//
// M, N:     block sizes. M is compared against 64 and used as the row step of
//           register repetitions; N caps the N extent held in registers.
// oldType:  tensor whose element bit width, per-CTA shape and CTA layout
//           drive the construction.
// numWarps: must be 4 or 8.
//
// Returns std::nullopt for numWarps == 8, M == 64, N <= 16 with sub-32-bit
// elements.
std::optional<LinearLayout>
getTmemLoadStoreLayout16x256(int M, int N, RankedTensorType oldType,
                             int numWarps) {
  if (numWarps == 8 && M == 64 && N <= 16 &&
      oldType.getElementTypeBitWidth() < 32) {
    return {};
  }
  assert(numWarps == 4 || numWarps == 8);
  auto ctaLayout = getCTALayout(oldType.getEncoding());
  // The per-CTA shape drives the register/warp split; the full shape is only
  // used by the final CGA combination below.
  SmallVector<int64_t> shape = getShapePerCTA(oldType);
  MLIRContext *ctx = ctaLayout.getContext();
  // NOTE(review): basisT and kLane are unused in this function.
  using basisT = std::vector<std::vector<int32_t>>;
  StringAttr kRegister = StringAttr::get(ctx, "register");
  StringAttr kLane = StringAttr::get(ctx, "lane");
  StringAttr kWarp = StringAttr::get(ctx, "warp");
  SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, 2);
  // Number of elements of this type in 256 bits: the N extent of the base
  // mma tile.
  unsigned numElementsPerThread = 256 / oldType.getElementTypeBitWidth();
  // Number of elements of this type in 64 bits.
  int kWidth = 64 / oldType.getElementTypeBitWidth();
  // Base tile: 8 x numElementsPerThread; see nvidiaMmaTile for the exact
  // lane/register assignment.
  LinearLayout innerTile =
      nvidiaMmaTile(ctx, {8, numElementsPerThread}, kWidth, {1, 0}, {0, 1});
  // One register bit along dim0 (rows).
  innerTile =
      innerTile * LinearLayout::identity1D(2, kRegister, outDimNames[0]);
  // With 8 warps, the extra warp bit goes along M when the per-CTA M exceeds
  // 128, otherwise along N.
  bool distributeMAlongWarps = false;
  bool distributeNAlongWarps = false;
  if (numWarps == 8) {
    if (shape[0] > 128) {
      distributeMAlongWarps = true;
    } else {
      distributeNAlongWarps = true;
    }
  }
  int nBase = numElementsPerThread;
  // N extent held in registers: min(N, per-CTA N), where per-CTA N is halved
  // when N is split across warps.
  int maxRegN =
      std::min(N, distributeNAlongWarps ? (int)shape[1] / 2 : (int)shape[1]);
  if (maxRegN / nBase > 1) {
    innerTile = innerTile * LinearLayout::identity1D(maxRegN / nBase, kRegister,
                                                     outDimNames[1]);
  }
  // For M != 64, one more register bit along dim0.
  if (M != 64) {
    innerTile =
        innerTile * LinearLayout::identity1D(2, kRegister, outDimNames[0]);
  }
  // Four warps stacked along dim0.
  innerTile = innerTile * LinearLayout::identity1D(4, kWarp, outDimNames[0]);
  // Register repetitions along dim0 up to min(128, per-CTA M), in steps of M.
  int numMRegDim = std::min(128, (int)shape[0]) / M;
  if (numMRegDim > 1) {
    innerTile = innerTile *
                LinearLayout::identity1D(numMRegDim, kRegister, outDimNames[0]);
  }
  // Row extent covered so far: 128, or 256 once the extra warp bit is placed
  // along M.
  int nextDim = 128;
  if (distributeMAlongWarps) {
    innerTile = innerTile * LinearLayout::identity1D(2, kWarp, outDimNames[0]);
    nextDim <<= 1;
  }
  // Remaining rows beyond nextDim are covered by more registers.
  numMRegDim = shape[0] / nextDim;
  if (numMRegDim > 1) {
    innerTile = innerTile *
                LinearLayout::identity1D(numMRegDim, kRegister, outDimNames[0]);
  }
  // Remaining N repetitions go to registers, then (optionally) the extra
  // warp bit along N.
  int maxN = distributeNAlongWarps ? shape[1] / 2 : shape[1];
  int numNRegDim = maxN / maxRegN;
  if (numNRegDim > 1) {
    innerTile = innerTile *
                LinearLayout::identity1D(numNRegDim, kRegister, outDimNames[1]);
  }
  if (distributeNAlongWarps) {
    innerTile = innerTile * LinearLayout::identity1D(2, kWarp, outDimNames[1]);
  }
  // Combine with the CTA layout using the full (not per-CTA) shape.
  return combineCtaCgaWithShape(innerTile, ctaLayout, oldType.getShape());
}
// Register/lane/warp layout for a TMEM load with M == 128 and 8 warps.
//
// - Lanes 0-15 walk 16 consecutive rows; lane bit 4 jumps to column N / 2.
// - Warps 0-3 step by 32 rows; warp bit 2 adds a 16-row offset.
// - Registers hold the first N / 2 columns, then column steps from N up to
//   the per-CTA N, then row steps from M up to the per-CTA M.
// The result is combined with the encoding's CTA layout over the full shape.
LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType,
                                         int numWarps) {
  assert(numWarps == 8);
  assert(M == 128);
  auto ctaLayout = getCTALayout(oldType.getEncoding());
  SmallVector<int64_t> perCtaShape = getShapePerCTA(oldType);
  MLIRContext *ctx = ctaLayout.getContext();
  using basisT = std::vector<std::vector<int32_t>>;
  StringAttr regDim = StringAttr::get(ctx, "register");
  StringAttr laneDim = StringAttr::get(ctx, "lane");
  StringAttr warpDim = StringAttr::get(ctx, "warp");

  basisT lanes;
  for (int32_t row = 1; row < 16; row *= 2)
    lanes.push_back({row, 0});
  lanes.push_back({0, N / 2});

  basisT regs;
  for (int32_t col = 1; col < N / 2; col *= 2)
    regs.push_back({0, col});
  for (int32_t col = N; col < perCtaShape[1]; col *= 2)
    regs.push_back({0, col});
  for (int32_t row = M; row < perCtaShape[0]; row *= 2)
    regs.push_back({row, 0});

  basisT warps = {{32, 0}, {64, 0}, {16, 0}};

  SmallVector<StringAttr> outDims = standardOutDimNames(ctx, 2);
  auto layout =
      LinearLayout({{regDim, regs}, {laneDim, lanes}, {warpDim, warps}},
                   {outDims[0], outDims[1]});
  return combineCtaCgaWithShape(layout, ctaLayout, oldType.getShape());
}
}