#include "triton/Dialect/Triton/IR/Dialect.h"
#include <cstdint>
#include <numeric>
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/Support/LLVM.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/Triton/IR/Interfaces.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LayoutUtility.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h"
#include "triton/Dialect/TritonGPU/IR/Types.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/LayoutUtils.h"
#include "triton/Tools/LinearLayout.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/MathExtras.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.cpp.inc"
#include "triton/Dialect/TritonGPU/IR/TypeInterfaces.cpp.inc"
using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;
namespace mlir {
namespace triton {
namespace gpu {
/// Returns the LinearEncodingAttr equivalent of `layout` applied to a tensor of
/// the given `shape`.
///
/// Results are memoized in `leCache`, keyed by the (shape, layout) pair, so
/// repeated queries for the same tensor type skip recomputing the linear
/// layout.
LinearEncodingAttr TritonGPUDialect::toLinearEncoding(ArrayRef<int64_t> shape,
                                                      Attribute layout) {
  CacheKey key{std::vector<int64_t>(shape.begin(), shape.end()), layout};
  if (auto result = leCache.get(key)) {
    return *result;
  }
  auto linearLayout = toLinearLayout(shape, layout);
  auto linearEncoding =
      LinearEncodingAttr::get(layout.getContext(), std::move(linearLayout));
  leCache.set(key, linearEncoding);
  return linearEncoding;
}
/// Converts a distributed `layout` over `shape` into a LinearEncodingAttr by
/// delegating to the TritonGPU dialect's (memoizing) conversion hook.
LinearEncodingAttr toLinearEncoding(DistributedEncodingTrait layout,
                                    ArrayRef<int64_t> shape) {
  MLIRContext *context = layout.getContext();
  auto *dialect = context->getLoadedDialect<TritonGPUDialect>();
  return dialect->toLinearEncoding(shape, layout);
}
/// Converts the encoding of ranked tensor `type` into a LinearEncodingAttr for
/// that tensor's shape.
LinearEncodingAttr toLinearEncoding(RankedTensorType type) {
  auto *dialect =
      type.getContext()->getLoadedDialect<TritonGPUDialect>();
  ArrayRef<int64_t> tensorShape = type.getShape();
  return dialect->toLinearEncoding(tensorShape, type.getEncoding());
}
/// Total number of elements each thread holds for a tensor of `shape` laid out
/// with the distributed encoding `layout`.
unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
  auto distributed = cast<DistributedEncodingTrait>(layout);
  LinearEncodingAttr linear = toLinearEncoding(distributed, shape);
  return linear.getTotalElemsPerThread(shape);
}
/// Per-dimension element counts held by each thread for a tensor of `shape`
/// laid out with the distributed encoding `layout`.
SmallVector<unsigned> getElemsPerThread(Attribute layout,
                                        ArrayRef<int64_t> shape) {
  auto distributed = cast<DistributedEncodingTrait>(layout);
  LinearEncodingAttr linear = toLinearEncoding(distributed, shape);
  return linear.getElemsPerThread(shape);
}
/// Per-dimension element counts per thread for `type`. Scalars (int, index,
/// float) and Triton pointers count as a single element; any other type must
/// be a RankedTensorType.
SmallVector<unsigned> getElemsPerThread(Type type) {
  bool isScalarLike =
      type.isIntOrIndexOrFloat() || isa<triton::PointerType>(type);
  if (isScalarLike)
    return SmallVector<unsigned>(1, 1);
  auto rankedTy = cast<RankedTensorType>(type);
  Attribute encoding = rankedTy.getEncoding();
  return getElemsPerThread(encoding, rankedTy.getShape());
}
/// Total element count per thread for `type`. Scalars (int, index, float) and
/// Triton pointers count as one element; any other type must be a
/// RankedTensorType.
unsigned getTotalElemsPerThread(Type type) {
  bool isScalarLike =
      type.isIntOrIndexOrFloat() || isa<triton::PointerType>(type);
  if (isScalarLike)
    return 1;
  auto rankedTy = cast<RankedTensorType>(type);
  Attribute encoding = rankedTy.getEncoding();
  return getTotalElemsPerThread(encoding, rankedTy.getShape());
}
/// Threads per warp along each dimension for the distributed `layout` at
/// `shape`, read off its linear-encoding form.
SmallVector<unsigned> getThreadsPerWarp(Attribute layout,
                                        ArrayRef<int64_t> shape) {
  auto distributed = cast<DistributedEncodingTrait>(layout);
  LinearEncodingAttr linear = toLinearEncoding(distributed, shape);
  return linear.getThreadsPerWarp();
}
/// Warps per CTA along each dimension for the distributed `layout` at `shape`,
/// read off its linear-encoding form.
SmallVector<unsigned> getWarpsPerCTA(Attribute layout,
                                     ArrayRef<int64_t> shape) {
  auto distributed = cast<DistributedEncodingTrait>(layout);
  LinearEncodingAttr linear = toLinearEncoding(distributed, shape);
  return linear.getWarpsPerCTA();
}
/// Per-dimension count of contiguous elements held by one thread for `type`.
SmallVector<unsigned> getContigPerThread(RankedTensorType type) {
  LinearEncodingAttr linear = toLinearEncoding(type);
  return linear.getContigPerThread();
}
/// Returns true when viewing a tensor of `srcType` as `dstType` is not free.
/// That is the case when the free-variable masks of the two linear layouts
/// differ for some input dimension, or when the per-thread element counts
/// differ.
bool isExpensiveView(Type srcType, Type dstType) {
  auto srcLayout = toLinearLayout(cast<RankedTensorType>(srcType));
  auto dstLayout = toLinearLayout(cast<RankedTensorType>(dstType));
  auto srcMasks = srcLayout.getFreeVariableMasks();
  auto dstMasks = dstLayout.getFreeVariableMasks();
  for (auto [src, dst] : llvm::zip(srcMasks, dstMasks)) {
    // Both layouts are expected to enumerate the same input dims in order.
    assert(src.first == dst.first);
    if (src.second != dst.second)
      return true;
  }
  unsigned srcElems = getTotalElemsPerThread(srcType);
  unsigned dstElems = getTotalElemsPerThread(dstType);
  return srcElems != dstElems;
}
* Erase dim and decrease all values larger than dim by 1.
* Example: order = [0, 2, 4, 3, 1], dim = 2
* resOrder = [0, 3, 2, 1]
*/
static SmallVector<unsigned> eraseOrder(ArrayRef<unsigned> order,
unsigned dim) {
unsigned rank = order.size();
assert(dim < rank && "Invalid dim to erase");
SmallVector<unsigned> resOrder;
for (unsigned i : order)
if (i < dim)
resOrder.push_back(i);
else if (i > dim)
resOrder.push_back(i - 1);
return resOrder;
}
/// Builds a dimension order for a rank-`rank` matrix-like tensor.
///
/// For rank >= 2 the result is [rank-1, rank-2, ..., 0]. When `rowMajor` is
/// false, the first two entries are swapped. Ranks 0 and 1 yield `rank` zeros
/// (i.e. [] or [0]).
SmallVector<unsigned> getMatrixOrder(unsigned rank, bool rowMajor) {
  SmallVector<unsigned> order(rank);
  if (rank >= 2) {
    for (unsigned pos = 0; pos < rank; ++pos)
      order[pos] = rank - 1 - pos;
    if (!rowMajor)
      std::swap(order[0], order[1]);
  }
  return order;
}
/// Dimension order for dot operand `opIdx` (0 or 1) of the given rank.
/// Operand 0 is row-major exactly when K is contiguous; operand 1 is row-major
/// exactly when K is not contiguous.
SmallVector<unsigned> getOrderForDotOperand(unsigned opIdx, unsigned rank,
                                            bool kContig) {
  assert(opIdx == 0 || opIdx == 1);
  bool rowMajor = (opIdx == 0) ? kContig : !kContig;
  return getMatrixOrder(rank, rowMajor);
}
/// Repetition order of a tensor's encoding. Only distributed encodings are
/// supported; any other encoding is a fatal error.
SmallVector<unsigned> getRepOrder(RankedTensorType type) {
  auto distributed =
      mlir::dyn_cast<DistributedEncodingTrait>(type.getEncoding());
  if (!distributed)
    llvm::report_fatal_error("Unimplemented usage of getRepOrder");
  return distributed.getRepOrder();
}
/// Dimension order of a shared-memory encoding.
///
/// Swizzled, padded, and AMD rotating encodings store an explicit order. For
/// NVMMA encodings the order is computed here: [0] for 1-D shapes, otherwise a
/// matrix order that is row-major unless the layout is transposed. Any other
/// shared encoding is a fatal error.
SmallVector<unsigned> getOrder(SharedEncodingTrait layout,
                               ArrayRef<int64_t> shape) {
  if (auto swizzled = dyn_cast<SwizzledSharedEncodingAttr>(layout))
    return llvm::to_vector(swizzled.getOrder());
  if (auto padded = dyn_cast<PaddedSharedEncodingAttr>(layout))
    return llvm::to_vector(padded.getOrder());
  if (auto nvmma = dyn_cast<NVMMASharedEncodingAttr>(layout)) {
    if (shape.size() == 1)
      return {0};
    bool rowMajor = !nvmma.getTransposed();
    return getMatrixOrder(shape.size(), rowMajor);
  }
  if (auto rotating = dyn_cast<AMDRotatingSharedEncodingAttr>(layout))
    return llvm::to_vector(rotating.getOrder());
  llvm::report_fatal_error("Unimplemented usage of getOrder for MemDescType");
  return {};
}
/// Dimension order of a distributed layout at `shape`, taken from its
/// linear-encoding form.
SmallVector<unsigned> getOrder(DistributedEncodingTrait layout,
                               ArrayRef<int64_t> shape) {
  LinearEncodingAttr linear = toLinearEncoding(layout, shape);
  return linear.getOrder();
}
/// Chooses a dimension order to use when a distributed layout is stored to
/// memory.
///
/// This is the register order, except in one heuristic case. When the register
/// order and the lane (thread) order disagree, and each thread holds exactly
/// one element along the lane order's leading dimension, that dimension's
/// register position carries no information, so the lane order is preferred.
SmallVector<unsigned> getOrderForMemory(DistributedEncodingTrait layout,
                                        ArrayRef<int64_t> shape) {
  auto linear = toLinearEncoding(layout, shape);
  auto registerOrder = linear.getOrder();
  auto laneOrder = linear.getThreadOrder();
  bool preferLaneOrder =
      registerOrder != laneOrder &&
      linear.getElemsPerThread(shape)[laneOrder[0]] == 1;
  return preferLaneOrder ? laneOrder : registerOrder;
}
/// Lane (thread) order of a distributed layout at `shape`.
SmallVector<unsigned> getThreadOrder(DistributedEncodingTrait layout,
                                     ArrayRef<int64_t> shape) {
  LinearEncodingAttr linear = toLinearEncoding(layout, shape);
  return linear.getThreadOrder();
}
/// Warp order of a distributed layout at `shape`.
SmallVector<unsigned> getWarpOrder(DistributedEncodingTrait layout,
                                   ArrayRef<int64_t> shape) {
  LinearEncodingAttr linear = toLinearEncoding(layout, shape);
  return linear.getWarpOrder();
}
/// Rebuilds a CTALayoutAttr from the CTA fields of a TritonGPU layout
/// encoding. Any other attribute is a fatal error.
CTALayoutAttr getCTALayout(Attribute layout) {
  auto ttgLayout = mlir::dyn_cast<LayoutEncodingTrait>(layout);
  if (!ttgLayout)
    llvm::report_fatal_error("Unimplemented usage of getCTALayout");
  return CTALayoutAttr::get(layout.getContext(), getCTAsPerCGA(ttgLayout),
                            getCTASplitNum(ttgLayout), getCTAOrder(ttgLayout));
}
/// CTAs per CGA along each dimension of a TritonGPU layout encoding. Any
/// other attribute is a fatal error.
SmallVector<unsigned> getCTAsPerCGA(Attribute layout) {
  auto ttgLayout = mlir::dyn_cast<LayoutEncodingTrait>(layout);
  if (!ttgLayout)
    llvm::report_fatal_error("Unimplemented usage of getCTAsPerCGA");
  return ttgLayout.getCTAsPerCGA();
}
/// Number of CTA splits per dimension for `layout`.
///
/// TritonGPU layout encodings report it directly. The NVIDIA tensor-memory
/// encodings supply the two values (M, N) through getCTASplitM/N. Any other
/// attribute asserts in debug builds and yields an empty vector otherwise.
SmallVector<unsigned> getCTASplitNum(Attribute layout) {
  if (auto ttgLayout = mlir::dyn_cast<LayoutEncodingTrait>(layout))
    return ttgLayout.getCTASplitNum();

  SmallVector<unsigned> splits;
  if (auto tmem =
          mlir::dyn_cast<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
              layout)) {
    splits.push_back(tmem.getCTASplitM());
    splits.push_back(tmem.getCTASplitN());
    return splits;
  }
  if (auto tmemScales = mlir::dyn_cast<
          triton::nvidia_gpu::TensorMemoryScalesEncodingAttr>(layout)) {
    splits.push_back(tmemScales.getCTASplitM());
    splits.push_back(tmemScales.getCTASplitN());
    return splits;
  }
  assert(false && "Unimplemented usage of getCTASplitNum");
  return splits;
}
/// CTA order of a TritonGPU layout encoding. Any other attribute is a fatal
/// error.
SmallVector<unsigned> getCTAOrder(Attribute layout) {
  auto ttgLayout = mlir::dyn_cast<LayoutEncodingTrait>(layout);
  if (!ttgLayout)
    llvm::report_fatal_error("Unimplemented usage of getCTAOrder");
  return ttgLayout.getCTAOrder();
}
/// Computes the per-CTA shape of a tensor of `shape` split across CTAs
/// according to `CTASplitNum`.
///
/// `CTASplitNum` is aligned to the trailing dimensions of `shape`: missing
/// leading entries count as 1, and surplus leading entries are dropped. A
/// dimension is never split into more pieces than it has elements.
SmallVector<int64_t> getShapePerCTA(ArrayRef<unsigned> CTASplitNum,
                                    ArrayRef<int64_t> shape) {
  unsigned rank = shape.size();
  auto splitNum = llvm::to_vector(CTASplitNum);
  if (splitNum.size() <= rank) {
    splitNum.insert(splitNum.begin(), rank - splitNum.size(), 1);
  } else {
    splitNum =
        llvm::to_vector(llvm::drop_begin(splitNum, splitNum.size() - rank));
  }
  SmallVector<int64_t> shapePerCTA(rank);
  for (unsigned i = 0; i < rank; ++i) {
    // Clamp in int64_t so large extents are not truncated to 32 bits, and keep
    // the divisor >= 1 so an empty (size-0) dimension does not divide by zero.
    int64_t divisor = std::min<int64_t>(shape[i], splitNum[i]);
    shapePerCTA[i] = shape[i] / std::max<int64_t>(divisor, 1);
  }
  return shapePerCTA;
}
/// Per-CTA shape of a tensor of `shape` under `layout`'s CTA split counts.
SmallVector<int64_t> getShapePerCTA(Attribute layout, ArrayRef<int64_t> shape) {
  SmallVector<unsigned> splitNum = getCTASplitNum(layout);
  return getShapePerCTA(splitNum, shape);
}
/// Per-CTA allocation shape for a tensor whose logical shape is
/// `shapeLogical`.
///
/// For FP4-padded NVMMA shared layouts, the extent along the first dimension
/// of the layout's order is doubled before the per-CTA split is applied.
SmallVector<int64_t> getAllocationShapePerCTA(Attribute layout,
                                              ArrayRef<int64_t> shapeLogical) {
  SmallVector<int64_t> allocShape(shapeLogical);
  auto nvmma = dyn_cast<NVMMASharedEncodingAttr>(layout);
  if (nvmma && nvmma.getFp4Padded()) {
    auto packedAxis = getOrder(nvmma, shapeLogical)[0];
    allocShape[packedAxis] *= 2;
  }
  return getShapePerCTA(layout, allocShape);
}
/// Per-CTA shape of a tensor or memdesc type.
SmallVector<int64_t> getShapePerCTA(Type type) {
  auto shaped = cast<TensorOrMemDesc>(type);
  Attribute encoding = shaped.getEncoding();
  return getShapePerCTA(encoding, shaped.getShape());
}
/// Per-CTA allocation shape of a tensor or memdesc type.
SmallVector<int64_t> getAllocationShapePerCTA(Type type) {
  auto shaped = cast<TensorOrMemDesc>(type);
  Attribute encoding = shaped.getEncoding();
  return getAllocationShapePerCTA(encoding, shaped.getShape());
}
/// Total number of CTAs: the product of CTAs per CGA over all dimensions.
unsigned getNumCTAs(Attribute layout) {
  SmallVector<unsigned> ctasPerCGA = getCTAsPerCGA(layout);
  return product<unsigned>(ctasPerCGA);
}
/// A cat is considered expensive when re-encoding its result with
/// `targetEncoding` would leave each thread with fewer elements than the
/// result's current encoding does.
bool isExpensiveCat(CatOp cat, Attribute targetEncoding) {
  RankedTensorType resultTy = cat.getType();
  unsigned currentElems = gpu::getTotalElemsPerThread(resultTy);
  unsigned targetElems =
      gpu::getTotalElemsPerThread(targetEncoding, resultTy.getShape());
  return targetElems < currentElems;
}
/// Checks that `order` is a permutation of 0..(rank-1), emitting a diagnostic
/// otherwise.
static LogicalResult
verifyLayoutOrder(function_ref<InFlightDiagnostic()> emitError,
                  ArrayRef<unsigned> order) {
  if (isPermutationOfIota(order))
    return success();
  return emitError()
         << "order must be a permutation of 0..(rank-1), but was [" << order
         << "]";
}
/// Verifies a CTA layout. The checks, in order:
///  - all three arrays have the same rank;
///  - CTAOrder is a permutation;
///  - no count in CTAsPerCGA is zero;
///  - no count in CTASplitNum is zero.
LogicalResult CTALayoutAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<unsigned> CTAsPerCGA,
    ArrayRef<unsigned> CTASplitNum, ArrayRef<unsigned> CTAOrder) {
  bool sameRank = CTAsPerCGA.size() == CTASplitNum.size() &&
                  CTASplitNum.size() == CTAOrder.size();
  if (!sameRank)
    return emitError() << "CTAsPerCGA, CTASplitNum, and CTAOrder must all have "
                          "the same rank.";
  if (failed(verifyLayoutOrder(emitError, CTAOrder)))
    return failure();

  auto isZero = [](unsigned count) { return count == 0; };
  if (llvm::any_of(CTAsPerCGA, isZero))
    return emitError() << "Every element in CTAsPerCGA must be greater than 0.";
  if (llvm::any_of(CTASplitNum, isZero))
    return emitError()
           << "Every element in CTASplitNum must be greater than 0.";
  return success();
}
/// Verifies a blocked encoding. The checks, in order:
///  - all per-dimension arrays share one rank;
///  - every entry of sizePerThread, threadsPerWarp, and warpsPerCTA is a
///    power of two (so zero is rejected);
///  - the CTA layout has the same rank;
///  - `order` is a permutation.
LogicalResult
BlockedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                            ArrayRef<unsigned> sizePerThread,
                            ArrayRef<unsigned> threadsPerWarp,
                            ArrayRef<unsigned> warpsPerCTA,
                            ArrayRef<unsigned> order, CTALayoutAttr CTALayout) {
  if (!llvm::all_equal({sizePerThread.size(), threadsPerWarp.size(),
                        warpsPerCTA.size(), order.size()}))
    return emitError() << "sizePerThread, threadsPerWarp, warpsPerCTA, and "
                          "order must all have the same rank.";

  auto hasNonPowerOf2 = [](ArrayRef<unsigned> values) {
    return llvm::any_of(values,
                        [](unsigned x) { return !llvm::isPowerOf2_64(x); });
  };
  if (hasNonPowerOf2(sizePerThread))
    return emitError()
           << "Every element in sizePerThread must be a power of two.";
  if (hasNonPowerOf2(threadsPerWarp))
    return emitError()
           << "Every element in threadsPerWarp must be a power of two.";
  if (hasNonPowerOf2(warpsPerCTA))
    return emitError()
           << "Every element in warpsPerCTA must be a power of two.";

  if (order.size() != CTALayout.getRank())
    return emitError() << "BlockedEncodingAttr and CTALayout's fields must "
                          "have the same rank.";
  return verifyLayoutOrder(emitError, order);
}
/// Builds a blocked encoding for `shape` with one element per thread and the
/// reversed dimension order [rank-1, ..., 1, 0]. Warp, thread, and CTA counts
/// are passed through to BlockedEncodingAttr::get.
triton::gpu::BlockedEncodingAttr
getDefaultBlockedEncoding(MLIRContext *context, ArrayRef<int64_t> shape,
                          int numWarps, int threadsPerWarp, int numCTAs) {
  int rank = shape.size();
  llvm::SmallVector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0);
  llvm::SmallVector<unsigned> sizePerThread(rank, 1);
  return triton::gpu::BlockedEncodingAttr::get(context, shape, sizePerThread,
                                               order, numWarps, threadsPerWarp,
                                               numCTAs);
}
/// Relates `inLl` to the layout obtained by adding (forward) or removing
/// (backward) a size-2 factor along output dimension `axis`. Per the
/// diagnostic below, this is presumably shared by SplitOp / Fp4ToFpOp
/// inference.
///
/// - fwdInference == true: `outLl` is the product `split * inLl`, where
///   `split` is a size-2 identity layout mapping the "register" input
///   dimension onto output dimension `axis`.
/// - fwdInference == false: `outLl` is built from `inLl` by:
///   - dropping every "register" basis whose `axis` coordinate is 1;
///   - halving (integer division) the `axis` coordinate of every remaining
///     basis of every input dimension.
///
/// Returns failure via emitOptionalError(loc, ...) when backward inference
/// finds no such register basis; otherwise returns success.
LogicalResult tryJoinOnAxis(MLIRContext *ctx, const LinearLayout &inLl,
                            LinearLayout &outLl, bool fwdInference, int axis,
                            std::optional<Location> loc) {
  auto kRegister = StringAttr::get(ctx, "register");
  auto outDims = llvm::to_vector(inLl.getOutDimNames());
  if (fwdInference) {
    auto split = LinearLayout::identity1D(2, kRegister, outDims[axis]);
    outLl = split * inLl;
  } else {
    // Tracks whether a register basis of value 1 along `axis` was removed.
    bool found = false;
    LinearLayout::BasesT newBases;
    for (const auto &basesDim : inLl.getBases()) {
      std::vector<std::vector<int32_t>> newBasesDim;
      // `base` is taken by value so it can be modified and moved below.
      for (auto base : basesDim.second) {
        if (base[axis] == 1 && basesDim.first == kRegister) {
          found = true;
          continue;
        }
        base[axis] /= 2;
        newBasesDim.push_back(std::move(base));
      }
      newBases.insert({basesDim.first, std::move(newBasesDim)});
    }
    if (!found)
      return emitOptionalError(loc,
                               "Fp4ToFpOp/SplitOp requires at least 2 elements "
                               "per thread in the axis/last dimension");
    outLl = LinearLayout(std::move(newBases), std::move(outDims));
  }
  return success();
}
}
}
}
/// Parses `attr` as a non-negative integer and stores it in `value`.
///
/// Signed and signless integers are rejected when negative. Other integer
/// kinds are read with getUInt(). `desc` names the field in diagnostics.
static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr,
                                       unsigned &value, StringRef desc) {
  auto intAttr = mlir::dyn_cast<IntegerAttr>(attr);
  if (!intAttr) {
    parser.emitError(parser.getNameLoc(), "expected an integer type in ")
        << desc;
    return failure();
  }

  Type intType = intAttr.getType();
  bool isSigned = intType.isSignedInteger();
  if (!isSigned && !intType.isSignlessInteger()) {
    value = intAttr.getUInt();
    return success();
  }

  int64_t attrVal = isSigned ? intAttr.getSInt() : intAttr.getInt();
  if (attrVal < 0) {
    parser.emitError(parser.getNameLoc(),
                     "expected an unsigned integer value in ")
        << desc;
    return failure();
  }
  value = attrVal;
  return success();
}
/// Parses `attr` as a BoolAttr into `value`; `desc` names the field in
/// diagnostics.
static LogicalResult parseBoolAttrValue(AsmParser &parser, Attribute attr,
                                        bool &value, StringRef desc) {
  if (auto boolAttr = mlir::dyn_cast<BoolAttr>(attr)) {
    value = boolAttr.getValue();
    return success();
  }
  parser.emitError(parser.getNameLoc(), "expected a bool type in ") << desc;
  return failure();
}
/// Parses the value of `attr` as an array of non-negative integers, appending
/// each one to `res`. Parsing stops at the first bad element; elements already
/// parsed stay in `res`.
static LogicalResult parseIntArrayAttr(AsmParser &parser,
                                       const NamedAttribute &attr,
                                       SmallVector<unsigned> &res,
                                       StringRef desc) {
  auto arrayAttr = mlir::dyn_cast<ArrayAttr>(attr.getValue());
  if (!arrayAttr) {
    parser.emitError(parser.getNameLoc(), "expected an array for ") << desc;
    return failure();
  }
  for (Attribute element : arrayAttr) {
    unsigned parsed;
    if (failed(parseIntAttrValue(parser, element, parsed, desc)))
      return failure();
    res.push_back(parsed);
  }
  return success();
}
/// Parses the value of the named attribute `attr` as a non-negative integer.
static LogicalResult parseUInt(AsmParser &parser, const NamedAttribute &attr,
                               unsigned &value, StringRef desc) {
  Attribute raw = attr.getValue();
  return parseIntAttrValue(parser, raw, value, desc);
}
/// Parses the value of the named attribute `attr` as a boolean.
static LogicalResult parseBool(AsmParser &parser, const NamedAttribute &attr,
                               bool &value, StringRef desc) {
  Attribute raw = attr.getValue();
  return parseBoolAttrValue(parser, raw, value, desc);
}
/// Parses the value of the named attribute `attr` as a TypeAttr, storing the
/// wrapped type in `value`.
static LogicalResult parseType(AsmParser &parser, const NamedAttribute &attr,
                               std::optional<Type> &value, StringRef desc) {
  if (auto typeAttr = mlir::dyn_cast<TypeAttr>(attr.getValue())) {
    value = typeAttr.getValue();
    return success();
  }
  parser.emitError(parser.getNameLoc(), "expected a Type in ") << desc;
  return failure();
}
/// Prints the CTAsPerCGA / CTASplitNum / CTAOrder fields. Nothing is printed
/// when `layout` equals the default CTA layout for `rank`.
static void maybePrintCTALayout(mlir::MLIRContext *context,
                                mlir::AsmPrinter &printer, CTALayoutAttr layout,
                                unsigned rank) {
  if (layout == CTALayoutAttr::getDefault(context, rank))
    return;
  printer << ", CTAsPerCGA = [" << ArrayRef(layout.getCTAsPerCGA()) << "]";
  printer << ", CTASplitNum = [" << ArrayRef(layout.getCTASplitNum()) << "]";
  printer << ", CTAOrder = [" << ArrayRef(layout.getCTAOrder()) << "]";
}
#include "triton/Dialect/TritonGPU/IR/AttrInterfaces.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc"
/// Builds the CTA layout from optionally-parsed fields.
///
/// All three present yields an explicit layout; all three absent yields the
/// default layout for `rank`. A partial set emits an error and returns
/// std::nullopt.
static std::optional<CTALayoutAttr> getCTALayoutOrError(
    AsmParser &parser, std::optional<SmallVector<unsigned>> CTAsPerCGA,
    std::optional<SmallVector<unsigned>> CTASplitNum,
    std::optional<SmallVector<unsigned>> CTAOrder, unsigned rank) {
  unsigned numPresent = static_cast<unsigned>(CTAsPerCGA.has_value()) +
                        static_cast<unsigned>(CTASplitNum.has_value()) +
                        static_cast<unsigned>(CTAOrder.has_value());
  if (numPresent == 3)
    return CTALayoutAttr::get(parser.getContext(), *CTAsPerCGA, *CTASplitNum,
                              *CTAOrder);
  if (numPresent == 0)
    return CTALayoutAttr::getDefault(parser.getContext(), rank);
  parser.emitError(parser.getNameLoc(), "CTAsPerCGA, CTASplitNum, and CTAOrder "
                                        "must all be present or all be absent");
  return std::nullopt;
}
/// Parses `<{key = [ints], ...}>` into a BlockedEncodingAttr.
///
/// Each recognized key maps to one integer-array field; an unrecognized key is
/// an error. The CTA fields must be given all together or not at all. The
/// result goes through getChecked, so the verifier runs.
Attribute BlockedEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  SmallVector<unsigned> sizePerThread;
  SmallVector<unsigned> threadsPerWarp;
  SmallVector<unsigned> warpsPerCTA;
  SmallVector<unsigned> order;
  std::optional<SmallVector<unsigned>> CTAsPerCGA;
  std::optional<SmallVector<unsigned>> CTASplitNum;
  std::optional<SmallVector<unsigned>> CTAOrder;

  for (const NamedAttribute &attr : dict) {
    StringRef key = attr.getName().strref();
    // Select the destination field and its diagnostic description.
    SmallVector<unsigned> *dest = nullptr;
    StringRef desc;
    if (key == "sizePerThread") {
      dest = &sizePerThread;
      desc = "number of elements per thread";
    } else if (key == "threadsPerWarp") {
      dest = &threadsPerWarp;
      desc = "number of threads per warp";
    } else if (key == "warpsPerCTA") {
      dest = &warpsPerCTA;
      desc = "number of warps per CTA";
    } else if (key == "order") {
      dest = &order;
      desc = "order";
    } else if (key == "CTAsPerCGA") {
      dest = &CTAsPerCGA.emplace();
      desc = "CTAsPerCGA";
    } else if (key == "CTASplitNum") {
      dest = &CTASplitNum.emplace();
      desc = "CTASplitNum";
    } else if (key == "CTAOrder") {
      dest = &CTAOrder.emplace();
      desc = "CTAOrder";
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << key;
      return {};
    }
    if (parseIntArrayAttr(parser, attr, *dest, desc).failed())
      return {};
  }

  std::optional<CTALayoutAttr> CTALayout = getCTALayoutOrError(
      parser, CTAsPerCGA, CTASplitNum, CTAOrder, sizePerThread.size());
  if (!CTALayout.has_value())
    return {};
  return parser.getChecked<BlockedEncodingAttr>(parser.getContext(),
                                                sizePerThread, threadsPerWarp,
                                                warpsPerCTA, order, *CTALayout);
}
/// Prints the encoding as `<{sizePerThread = [...], ...}>`. The CTA fields are
/// appended only when they differ from the default.
void BlockedEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<{";
  printer << "sizePerThread = [" << ArrayRef(getSizePerThread()) << "]";
  printer << ", threadsPerWarp = [" << ArrayRef(getThreadsPerWarp()) << "]";
  printer << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]";
  printer << ", order = [" << getOrder() << "]";
  unsigned rank = getSizePerThread().size();
  maybePrintCTALayout(getContext(), printer, getCTALayout(), rank);
  printer << "}>";
}
/// Verifies a linear encoding. The checks, in order:
///  - input dims are register, lane, warp, block (only the overlapping prefix
///    is compared);
///  - output dims are named dim0, dim1, ...;
///  - every basis has at most one non-zero coordinate.
LogicalResult
LinearEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                           LinearLayout linearLayout) {
  static const auto expectedInDims =
      SmallVector<std::string>({"register", "lane", "warp", "block"});
  for (const auto &[idx, pair] : llvm::enumerate(
           llvm::zip(linearLayout.getInDimNames(), expectedInDims))) {
    const auto &[actual, expected] = pair;
    if (actual.str() == expected)
      continue;
    return emitError() << "Expected input dimension " << idx << " to be '"
                       << expected << "'. Got " << actual;
  }

  for (auto [idx, outDim] : llvm::enumerate(linearLayout.getOutDimNames())) {
    std::string expectedName = ("dim" + llvm::Twine(idx)).str();
    if (outDim.str() == expectedName)
      continue;
    return emitError()
           << "Expected output dimensions to be ['dim0', 'dim1', ...]. Got "
           << outDim << " at position " << idx;
  }

  auto movesInAtMostOneDim = [](const auto &basis) {
    auto nonZeroCount = std::count_if(basis.begin(), basis.end(),
                                      [](auto v) { return v != 0; });
    return nonZeroCount <= 1;
  };
  for (const auto &dimBases :
       llvm::make_second_range(linearLayout.getBases())) {
    if (llvm::all_of(dimBases, movesInAtMostOneDim))
      continue;
    return emitError()
           << "In a distributed layout, each base must move in at most one "
              "dimension.";
  }
  return success();
}
/// A blocked layout's repetition order is its element order.
SmallVector<unsigned> BlockedEncodingAttr::getRepOrder() const {
  auto order = getOrder();
  return SmallVector<unsigned>(order.begin(), order.end());
}
/// Prints the bases as `<{register = [[..], ..], lane = [...], ...}>`.
void LinearEncodingAttr::print(mlir::AsmPrinter &printer) const {
  auto formatBasis = [](const std::vector<int32_t> &vec) {
    return "[" + join(vec, ", ") + "]";
  };
  auto formatInDim = [&formatBasis](const auto &base) {
    return base.first.str() + " = " + "[" +
           join(base.second, ", ", formatBasis) + "]";
  };
  auto ll = getLinearLayout();
  printer << "<{" << join(ll.getBases(), ", ", formatInDim) << "}>";
}
/// Parses `<{register = [[...], ...], lane = [...], warp = [...],
/// block = [...]}>` into a LinearEncodingAttr.
///
/// Each of the four input dimensions must be present as an array of integer
/// arrays. The output rank is the length of the first non-empty basis, and the
/// output dims are named dim0, dim1, .... The result goes through getChecked,
/// so the verifier runs.
Attribute LinearEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  LinearLayout::BasesT bases;
  std::vector<std::string> inDimNames = {"register", "lane", "warp", "block"};
  for (const auto &inDimNameStr : inDimNames) {
    auto inDimName = StringAttr::get(parser.getContext(), inDimNameStr);
    // dict.get() returns a null attribute for a missing key. A plain dyn_cast
    // on null asserts / dereferences null, so use the null-tolerant cast and
    // report a missing dimension as an ordinary parse error.
    Attribute value = dict.get(inDimName);
    auto arrayOfArraysAttr = mlir::dyn_cast_or_null<ArrayAttr>(value);
    if (!arrayOfArraysAttr) {
      parser.emitError(parser.getCurrentLocation(),
                       "Expected array of arrays for basis of '")
          << inDimName.getValue() << "'";
      return {};
    }
    std::vector<std::vector<int32_t>> inDimBases;
    for (Attribute arrayAttr : arrayOfArraysAttr) {
      auto intArrayAttr = mlir::dyn_cast<ArrayAttr>(arrayAttr);
      if (!intArrayAttr) {
        parser.emitError(parser.getCurrentLocation(),
                         "Expected array of integers in basis for '")
            << inDimName.getValue() << "'";
        return {};
      }
      std::vector<int32_t> basis;
      for (Attribute intAttr : intArrayAttr) {
        auto intValueAttr = mlir::dyn_cast<IntegerAttr>(intAttr);
        if (!intValueAttr) {
          parser.emitError(parser.getCurrentLocation(),
                           "Expected integer in basis for '")
              << inDimName.getValue() << "'";
          return {};
        }
        basis.push_back(intValueAttr.getInt());
      }
      inDimBases.push_back(std::move(basis));
    }
    bases[inDimName] = std::move(inDimBases);
  }

  // The output rank is inferred from the first non-empty set of bases.
  size_t rank = 0;
  for (const auto &basesDim : llvm::make_second_range(bases)) {
    if (!basesDim.empty()) {
      rank = basesDim[0].size();
      break;
    }
  }
  if (rank == 0) {
    parser.emitError(parser.getCurrentLocation(), "Empty Layout not supported");
    return {};
  }

  SmallVector<StringAttr> outDimNames;
  for (size_t i = 0; i < rank; ++i) {
    outDimNames.push_back(
        StringAttr::get(parser.getContext(), "dim" + llvm::Twine(i)));
  }
  LinearLayout linearLayout(std::move(bases), std::move(outDimNames));
  return parser.getChecked<LinearEncodingAttr>(parser.getContext(),
                                               std::move(linearLayout));
}
/// Counts, for each of the `rank` output dimensions, how far the bases of
/// input dimension `dimName` extend along it.
///
/// Every entry starts at 1. Each basis whose first non-zero coordinate is at
/// output dim d doubles ret[d]. All-zero bases are handled as follows:
///  - when `skipBroadcast` is true, they are ignored;
///  - otherwise, they double the entry of the most recently seen non-zero
///    basis's dim (dim 0 if none has been seen yet).
///
/// NOTE(review): `namedBases.find(dimName)` is dereferenced without a check, so
/// callers must ensure `dimName` is present in `namedBases`.
static SmallVector<unsigned>
basesPerDimImpl(const LinearLayout::BasesT &namedBases, StringAttr dimName,
                size_t rank, bool skipBroadcast = true) {
  const auto &bases = namedBases.find(dimName)->second;
  if (bases.empty()) {
    return SmallVector<unsigned>(rank, 1);
  }
  SmallVector<unsigned> ret(rank, 1);
  auto nonZero = [](auto val) { return val != 0; };
  // Output dim of the most recent non-zero basis; all-zero bases are
  // attributed to it when skipBroadcast is false.
  int nonZeroIdx = 0;
  for (const auto &basis : bases) {
    auto it = std::find_if(basis.begin(), basis.end(), nonZero);
    if (it != basis.end()) {
      nonZeroIdx = it - basis.begin();
      ret[nonZeroIdx] *= 2;
    } else if (!skipBroadcast) {
      ret[nonZeroIdx] *= 2;
    }
  }
  return ret;
}
/// Per-output-dimension extent of the bases of input dimension `dimName`.
/// All-zero bases are counted only when `skipBroadcast` is false.
SmallVector<unsigned>
LinearEncodingAttr::basesPerDim(StringAttr dimName, bool skipBroadcast) const {
  auto layout = getLinearLayout();
  return basesPerDimImpl(layout.getBases(), dimName, layout.getNumOutDims(),
                         skipBroadcast);
}
/// Order in which the bases of input dimension `dimName` first reach each
/// output dimension (first non-zero coordinate of each basis). Output
/// dimensions never reached are appended in `defaultOrder`'s sequence.
SmallVector<unsigned>
LinearEncodingAttr::orderPerDim(StringAttr dimName,
                                ArrayRef<unsigned> defaultOrder) const {
  auto layout = getLinearLayout();
  const auto &dimBases = layout.getBases().find(dimName)->second;
  llvm::SetVector<unsigned> seen;
  for (const auto &basis : dimBases) {
    auto firstNonZero =
        std::find_if(basis.begin(), basis.end(), [](auto v) { return v != 0; });
    if (firstNonZero == basis.end())
      continue;
    seen.insert(firstNonZero - basis.begin());
  }
  for (unsigned dim : defaultOrder)
    seen.insert(dim);
  return SmallVector<unsigned>(seen.begin(), seen.end());
}
/// A linear encoding's repetition order is its (register-derived) order.
SmallVector<unsigned> LinearEncodingAttr::getRepOrder() const {
  SmallVector<unsigned> order = getOrder();
  return order;
}
/// CTAs per CGA: extent of the "block" bases, counting all-zero bases too.
SmallVector<unsigned> LinearEncodingAttr::getCTAsPerCGA() const {
  StringAttr kBlock = StringAttr::get(getContext(), "block");
  return basesPerDim(kBlock, /*skipBroadcast=*/false);
}
/// CTA order derived from the "block" bases, falling back to getOrder().
SmallVector<unsigned> LinearEncodingAttr::getCTAOrder() const {
  StringAttr kBlock = StringAttr::get(getContext(), "block");
  return orderPerDim(kBlock, getOrder());
}
/// CTA split counts: extent of the non-zero "block" bases.
SmallVector<unsigned> LinearEncodingAttr::getCTASplitNum() const {
  StringAttr kBlock = StringAttr::get(getContext(), "block");
  return basesPerDim(kBlock);
}
/// Warps per CTA: extent of the non-zero "warp" bases.
SmallVector<unsigned> LinearEncodingAttr::getWarpsPerCTA() const {
  StringAttr kWarp = StringAttr::get(getContext(), "warp");
  return basesPerDim(kWarp);
}
/// Warp order derived from the "warp" bases, falling back to getOrder().
SmallVector<unsigned> LinearEncodingAttr::getWarpOrder() const {
  StringAttr kWarp = StringAttr::get(getContext(), "warp");
  return orderPerDim(kWarp, getOrder());
}
/// Threads per warp: extent of the non-zero "lane" bases.
SmallVector<unsigned> LinearEncodingAttr::getThreadsPerWarp() const {
  StringAttr kLane = StringAttr::get(getContext(), "lane");
  return basesPerDim(kLane);
}
/// Lane order derived from the "lane" bases, falling back to getOrder().
SmallVector<unsigned> LinearEncodingAttr::getThreadOrder() const {
  StringAttr kLane = StringAttr::get(getContext(), "lane");
  return orderPerDim(kLane, getOrder());
}
/// Per-dimension count of register elements per thread, after peeling
/// trailing register bases that tile the per-CTA shape.
///
/// ctaShape[d] starts as outDimSize[d] / CTASplitNum[d]. Register bases are
/// then examined from the back and popped while all of these hold:
///  - the basis is non-zero;
///  - its dim is new, or equals the most recently introduced dim;
///  - twice its value equals the remaining ctaShape along that dim, which is
///    then halved.
/// The remaining register bases are counted per dim, ignoring all-zero bases.
/// NOTE(review): the peeled bases presumably represent repetitions of the
/// per-thread tile — confirm against the layout conventions.
SmallVector<unsigned> LinearEncodingAttr::getSizePerThread() const {
  auto rank = getOrder().size();
  auto ll = getLinearLayout();
  auto ctx = getContext();
  auto kRegister = StringAttr::get(ctx, "register");
  llvm::SmallVector<unsigned> ctaShape;
  for (auto [shape, cgaNum] :
       llvm::zip(ll.getOutDimSizes(), getCTASplitNum())) {
    ctaShape.push_back(shape / cgaNum);
  }
  LinearLayout::BasesT bases = ll.getBases();
  llvm::SetVector<unsigned> reverseRepOrder;
  auto nonZero = [](auto val) { return val != 0; };
  // Work on a copy of the register bases so trailing ones can be popped.
  auto &registers = bases[kRegister];
  while (!registers.empty()) {
    auto &basis = registers.back();
    auto it = std::find_if(basis.begin(), basis.end(), nonZero);
    if (it == basis.end()) {
      break;
    }
    auto dim = it - basis.begin();
    // SetVector::insert only appends unseen dims, so back() != dim means this
    // dim was introduced earlier than the latest one.
    reverseRepOrder.insert(dim);
    if (dim != reverseRepOrder.back() || 2 * basis[dim] != ctaShape[dim]) {
      break;
    }
    ctaShape[dim] /= 2;
    registers.pop_back();
  }
  return basesPerDimImpl(bases, kRegister, rank);
}
/// Element order: the order in which the "register" bases reach each output
/// dim. Dims they never reach are appended in descending order
/// (rank-1, ..., 0).
SmallVector<unsigned> LinearEncodingAttr::getOrder() const {
  unsigned rank = getLinearLayout().getNumOutDims();
  SmallVector<unsigned> descending;
  for (unsigned d = rank; d > 0; --d)
    descending.push_back(d - 1);
  StringAttr kRegister = StringAttr::get(getContext(), "register");
  return orderPerDim(kRegister, descending);
}
/// Resizes this encoding's linear layout to exactly `shape`.
///
/// Steps:
///  1. Transpose the output dims into the repetition order.
///  2. Grow the layout to at least `shape`, then shrink it to at most `shape`.
///  3. Restore the canonical dim0, dim1, ... order.
/// NOTE(review): the rep-order transpose presumably controls which dims
/// receive added/removed bases first — confirm.
LinearLayout LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
  auto layout = getLinearLayout();
  auto canonicalDims = llvm::to_vector(layout.getOutDimNames());
  llvm::SmallVector<StringAttr> repOrderedDims;
  llvm::SmallDenseMap<StringAttr, int64_t> targetShape;
  for (unsigned dim : getRepOrder()) {
    StringAttr dimName = canonicalDims[dim];
    repOrderedDims.push_back(dimName);
    targetShape[dimName] = shape[dim];
  }
  layout = layout.transposeOuts(repOrderedDims);
  layout = ensureLayoutNotSmallerThan(layout, targetShape);
  layout = ensureLayoutNotLargerThan(layout, targetShape, false);
  layout = layout.transposeOuts(canonicalDims);
  return layout;
}
/// Per-dimension element count per thread at `shape`. The layout is first
/// resized to `shape`; its register bases are then counted, including
/// all-zero ones.
SmallVector<unsigned>
LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  LinearLayout resized = toLinearLayout(shape);
  auto resizedEncoding = get(getContext(), resized);
  StringAttr kRegister = StringAttr::get(getContext(), "register");
  return resizedEncoding.basesPerDim(kRegister, /*skipBroadcast=*/false);
}
/// Extends the per-dimension contiguity `lowerContig` using the bases of input
/// dimension `inDim`.
///
/// Output dims are visited in getOrder(). For the current dim, the next basis
/// is compared with the vector that is zero everywhere except `contig[dim]` at
/// `dim`. While they match, contig[dim] doubles and the scan advances. The base
/// iterator is shared across dims: a basis that fails to match one dim is then
/// tested against the next dim in order.
SmallVector<unsigned>
LinearEncodingAttr::getContig(const char *inDim,
                              SmallVector<unsigned int> lowerContig) const {
  auto ll = getLinearLayout();
  // NOTE(review): find() is dereferenced unchecked; `inDim` is assumed to
  // name an existing input dimension.
  const auto &bases =
      ll.getBases().find(StringAttr::get(getContext(), inDim))->second;
  auto order = getOrder();
  auto rank = order.size();
  SmallVector<unsigned> contig(lowerContig);
  auto basisIt = bases.begin();
  for (unsigned dim : order) {
    // Expected next basis: `contig[dim]` along `dim`, zero elsewhere.
    std::vector<int32_t> basis(rank, 0);
    basis[dim] = contig[dim];
    while (basisIt != bases.end() && *basisIt == basis) {
      contig[dim] *= 2;
      basis[dim] *= 2;
      ++basisIt;
    }
  }
  return contig;
}
/// Contiguous elements per thread, built from the "register" bases starting
/// at 1 in every dim.
SmallVector<unsigned> LinearEncodingAttr::getContigPerThread() const {
  size_t rank = getOrder().size();
  SmallVector<unsigned> ones(rank, 1);
  return getContig("register", ones);
}
/// Contiguous elements per warp: per-thread contiguity extended by the
/// "lane" bases.
SmallVector<unsigned> LinearEncodingAttr::getContigPerWarp() const {
  SmallVector<unsigned> perThread = getContigPerThread();
  return getContig("lane", perThread);
}
/// Total elements per thread at `shape`: the product of the per-dimension
/// counts.
unsigned
LinearEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape) const {
  SmallVector<unsigned> perDim = getElemsPerThread(shape);
  return product(perDim);
}
/// Parses `<{versionMajor = N, versionMinor = N, warpsPerCTA = [...], ...}>`
/// into an NvidiaMmaEncodingAttr.
///
/// Unrecognized keys are ignored, and missing versions default to 0. The CTA
/// fields must be given all together or not at all.
Attribute NvidiaMmaEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  unsigned versionMajor = 0;
  unsigned versionMinor = 0;
  SmallVector<unsigned> warpsPerCTA;
  std::optional<SmallVector<unsigned>> CTAsPerCGA;
  std::optional<SmallVector<unsigned>> CTASplitNum;
  std::optional<SmallVector<unsigned>> CTAOrder;
  SmallVector<unsigned> instrShape;

  for (const NamedAttribute &attr : dict) {
    StringAttr key = attr.getName();
    LogicalResult status = success();
    if (key == "versionMajor")
      status = parseUInt(parser, attr, versionMajor, "versionMajor");
    else if (key == "versionMinor")
      status = parseUInt(parser, attr, versionMinor, "versionMinor");
    else if (key == "warpsPerCTA")
      status = parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA");
    else if (key == "CTAsPerCGA")
      status = parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(),
                                 "CTAsPerCGA");
    else if (key == "CTASplitNum")
      status = parseIntArrayAttr(parser, attr, CTASplitNum.emplace(),
                                 "CTASplitNum");
    else if (key == "CTAOrder")
      status = parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder");
    else if (key == "instrShape")
      status = parseIntArrayAttr(parser, attr, instrShape, "instrShape");
    if (status.failed())
      return {};
  }

  std::optional<CTALayoutAttr> CTALayout = getCTALayoutOrError(
      parser, CTAsPerCGA, CTASplitNum, CTAOrder, warpsPerCTA.size());
  if (!CTALayout.has_value())
    return {};
  return parser.getChecked<NvidiaMmaEncodingAttr>(
      parser.getContext(), versionMajor, versionMinor, warpsPerCTA, *CTALayout,
      instrShape);
}
/// Prints the encoding as `<{versionMajor = .., versionMinor = ..,
/// warpsPerCTA = [..]`, then the non-default CTA fields, then
/// `, instrShape = [..]}>`.
void NvidiaMmaEncodingAttr::print(AsmPrinter &printer) const {
  printer << "<{versionMajor = " << getVersionMajor();
  printer << ", versionMinor = " << getVersionMinor();
  printer << ", warpsPerCTA = [" << ArrayRef(getWarpsPerCTA()) << "]";
  maybePrintCTALayout(getContext(), printer, getCTALayout(), getRank());
  printer << ", instrShape = [" << getInstrShape() << "]}>";
}
/// Parses `<{version = N, warpsPerCTA = [...], instrShape = [M, N], ...}>`
/// into an AMDMfmaEncodingAttr.
///
/// Defaults and requirements:
///  - Unrecognized keys are ignored.
///  - `tilesPerWarp` defaults to all ones.
///  - `isTransposed` defaults to false.
///  - `instrShape` must supply at least two entries; the first two are passed
///    as separate dimensions.
///  - The CTA fields must be given all together or not at all.
Attribute AMDMfmaEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  unsigned version = 0;
  SmallVector<unsigned> warpsPerCTA;
  SmallVector<unsigned> tilesPerWarp;
  SmallVector<unsigned> instrShape;
  // Initialized so an omitted `isTransposed` is not an indeterminate read.
  bool isTransposed = false;
  std::optional<SmallVector<unsigned>> CTAsPerCGA;
  std::optional<SmallVector<unsigned>> CTASplitNum;
  std::optional<SmallVector<unsigned>> CTAOrder;
  std::optional<Type> elementType;
  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "version") {
      if (parseUInt(parser, attr, version, "version").failed())
        return {};
    }
    if (attr.getName() == "warpsPerCTA") {
      if (parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA").failed())
        return {};
    }
    if (attr.getName() == "tilesPerWarp") {
      if (parseIntArrayAttr(parser, attr, tilesPerWarp, "tilesPerWarp")
              .failed())
        return {};
    }
    if (attr.getName() == "instrShape") {
      if (parseIntArrayAttr(parser, attr, instrShape, "instrShape").failed())
        return {};
    }
    if (attr.getName() == "isTransposed") {
      if (parseBool(parser, attr, isTransposed, "isTransposed").failed())
        return {};
    }
    if (attr.getName() == "CTAsPerCGA") {
      if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA")
              .failed())
        return {};
    }
    if (attr.getName() == "CTASplitNum") {
      if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum")
              .failed())
        return {};
    }
    if (attr.getName() == "CTAOrder") {
      if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder")
              .failed())
        return {};
    }
    if (attr.getName() == "elementType") {
      if (parseType(parser, attr, elementType, "elementType").failed())
        return {};
    }
  }

  if (tilesPerWarp.empty()) {
    tilesPerWarp.resize(warpsPerCTA.size(), 1);
  }

  // instrShape[0] and instrShape[1] are read below; a missing or short array
  // would otherwise index out of bounds.
  if (instrShape.size() < 2) {
    parser.emitError(parser.getNameLoc(),
                     "expected at least 2 elements in instrShape");
    return {};
  }

  std::optional<CTALayoutAttr> CTALayout = getCTALayoutOrError(
      parser, CTAsPerCGA, CTASplitNum, CTAOrder, warpsPerCTA.size());
  if (!CTALayout.has_value())
    return {};
  return parser.getChecked<AMDMfmaEncodingAttr>(
      parser.getContext(), version, warpsPerCTA, tilesPerWarp, instrShape[0],
      instrShape[1], isTransposed, *CTALayout, elementType);
}
/// Prints the MFMA encoding as `<{version = ..., warpsPerCTA = [...], ...}>`.
///
/// `tilesPerWarp` is printed only when some entry differs from 1.
/// `elementType` is printed only when it is set and not f32.
void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
  printer << "<{"
          << "version = " << getVersion()
          << ", warpsPerCTA = [" << getWarpsPerCTA() << "]";
  if (!hasUnitTilesPerWarp()) {
    printer << ", tilesPerWarp = [" << getTilesPerWarp() << "]";
  }
  printer << ", instrShape = [" << ArrayRef{getMDim(), getNDim()} << "]"
          << ", isTransposed = " << getIsTransposed();
  maybePrintCTALayout(getContext(), printer, getCTALayout(), getRank());
  if (getElementType() && !(getElementType()->isF32())) {
    // Render the type through a raw string stream so it appears in its plain
    // textual form.
    std::string typeStr;
    llvm::raw_string_ostream rso(typeStr);
    getElementType()->print(rso);
    printer << ", elementType = " << rso.str();
  }
  printer << "}>";
}
/// Verifies MFMA encoding parameters.
///
/// Checks that `version` is at most 4, that (mDim, nDim) is one of
/// (32,32), (16,16), (64,4), (4,64), and that `elementType`, when present,
/// is f64, f32 or i32.
LogicalResult AMDMfmaEncodingAttr::verify(
    function_ref<mlir::InFlightDiagnostic()> emitError, unsigned version,
    llvm::ArrayRef<unsigned int> warpsPerCTA,
    llvm::ArrayRef<unsigned int> tilesPerWarp, unsigned mDim, unsigned nDim,
    bool isTransposed, mlir::triton::gpu::CTALayoutAttr,
    std::optional<Type> elementType) {
  // `version` is unsigned, so the lower bound of the range is implicit.
  if (version > 4) {
    return emitError() << "version must be in the [0, 4] range";
  }
  const std::array<std::pair<unsigned, unsigned>, 4> validDims = {
      {{32, 32}, {16, 16}, {64, 4}, {4, 64}}};
  if (!llvm::is_contained(validDims, std::make_pair(mDim, nDim))) {
    return emitError() << "invalid (mDim, nDim) combination: (" << mDim << ", "
                       << nDim << ")";
  }
  if (elementType && !(elementType->isF64() || elementType->isF32() ||
                       elementType->isInteger(32))) {
    // Report the offending type; the previous version rendered it into a
    // string that was never used.
    return emitError() << "element type must be f64, f32, i32, or none, got "
                       << *elementType;
  }
  return success();
}
/// Parses `<{version = ..., isTranspose = ..., warpsPerCTA = [...], ...}>` for
/// the AMD WMMA encoding. Optional CTA layout keys are CTAsPerCGA,
/// CTASplitNum and CTAOrder; any other key is ignored.
Attribute AMDWmmaEncodingAttr::parse(AsmParser &parser, Type type) {
  DictionaryAttr dict;
  if (failed(parser.parseLess()) || failed(parser.parseAttribute(dict)) ||
      failed(parser.parseGreater()))
    return {};

  unsigned version = 0;
  bool isTransposed = false;
  SmallVector<unsigned> warpsPerCTA;
  std::optional<SmallVector<unsigned>> CTAsPerCGA;
  std::optional<SmallVector<unsigned>> CTASplitNum;
  std::optional<SmallVector<unsigned>> CTAOrder;

  for (const NamedAttribute &attr : dict) {
    LogicalResult status = success();
    if (attr.getName() == "version")
      status = parseUInt(parser, attr, version, "version");
    else if (attr.getName() == "isTranspose")
      status = parseBool(parser, attr, isTransposed, "isTranspose");
    else if (attr.getName() == "warpsPerCTA")
      status = parseIntArrayAttr(parser, attr, warpsPerCTA, "warpsPerCTA");
    else if (attr.getName() == "CTAsPerCGA")
      status = parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(),
                                 "CTAsPerCGA");
    else if (attr.getName() == "CTASplitNum")
      status = parseIntArrayAttr(parser, attr, CTASplitNum.emplace(),
                                 "CTASplitNum");
    else if (attr.getName() == "CTAOrder")
      status =
          parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder");
    if (failed(status))
      return {};
  }

  auto ctaLayout = getCTALayoutOrError(parser, CTAsPerCGA, CTASplitNum,
                                       CTAOrder, warpsPerCTA.size());
  if (!ctaLayout.has_value())
    return {};
  return parser.getChecked<AMDWmmaEncodingAttr>(
      parser.getContext(), version, isTransposed, warpsPerCTA, *ctaLayout);
}
void AMDWmmaEncodingAttr::print(AsmPrinter &printer) const {
  // The CTA layout rank is taken from the number of warpsPerCTA entries.
  auto warpsPerCTA = getWarpsPerCTA();
  printer << "<{" << "version = " << getVersion();
  printer << ", isTranspose = " << getIsTransposed();
  printer << ", warpsPerCTA = [" << ArrayRef(warpsPerCTA) << "]";
  maybePrintCTALayout(getContext(), printer, getCTALayout(),
                      warpsPerCTA.size());
  printer << "}>";
}
LogicalResult
AMDWmmaEncodingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
                            unsigned version, bool isTransposed,
                            llvm::ArrayRef<unsigned int> warpsPerCTA,
                            mlir::triton::gpu::CTALayoutAttr) {
  // Only WMMA versions 1 and 2 are accepted.
  const bool knownVersion = version == 1 || version == 2;
  if (!knownVersion)
    return emitError() << "WMMA version must be in the [1, 2] range";
  // Version is now 1 or 2; transposition is only allowed for version 2.
  if (isTransposed && version == 1)
    return emitError() << "Transposed WMMA is supported only for version 2";
  return success();
}
/// Parses `<{dim = N, parent = #layout}>` for a slice encoding.
///
/// Both keys are required. `dim` must be an integer attribute and `parent`
/// must implement DistributedEncodingTrait. Otherwise a diagnostic is emitted
/// and a null attribute is returned.
Attribute SliceEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  NamedAttrList attrs;
  if (parser.parseOptionalAttrDict(attrs).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  // A missing key yields a null Attribute. Plain cast/dyn_cast would assert on
  // it, so use the *_or_null variants and report an error instead.
  auto dimAttr = mlir::dyn_cast_or_null<IntegerAttr>(attrs.get("dim"));
  if (!dimAttr) {
    parser.emitError(parser.getNameLoc(),
                     "expected an integer 'dim' attribute");
    return {};
  }
  unsigned dim = dimAttr.getInt();

  auto parent =
      mlir::dyn_cast_or_null<DistributedEncodingTrait>(attrs.get("parent"));
  if (!parent) {
    parser.emitError(parser.getNameLoc(),
                     "expected a distributed encoding trait");
    return {};
  }
  return parser.getChecked<SliceEncodingAttr>(parser.getContext(), dim, parent);
}
void SliceEncodingAttr::print(mlir::AsmPrinter &printer) const {
  // Output form: `<{dim = N, parent = <parent layout>}>`.
  printer << "<{dim = " << getDim() << ", parent = " << getParent() << "}>";
}
LogicalResult
SliceEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                          unsigned dim, DistributedEncodingTrait parent) {
  // A slice removes one dimension, so the parent needs rank >= 2 and the
  // sliced dim must index into it.
  const unsigned parentRank = cast<LayoutEncodingTrait>(parent).getRank();
  if (parentRank < 2)
    return emitError() << "parent layout must have at least rank >= 2";
  if (dim < parentRank)
    return success();
  return emitError() << "slice dim=" << dim
                     << " must be less than the parent rank=" << parentRank;
}
SmallVector<unsigned> SliceEncodingAttr::getRepOrder() const {
  // Parent repetition order with the sliced dimension removed.
  auto order = getParent().getRepOrder();
  return eraseOrder(order, getDim());
}
SmallVector<unsigned> SliceEncodingAttr::getCTASplitNum() const {
  // Parent CTA split counts with the sliced dimension's entry dropped.
  const unsigned slicedDim = getDim();
  SmallVector<unsigned> splitNum = ::getCTASplitNum(getParent());
  splitNum.erase(splitNum.begin() + slicedDim);
  return splitNum;
}
SmallVector<unsigned> SliceEncodingAttr::getCTAOrder() const {
  // Parent CTA order with the sliced dimension removed.
  auto order = ::getCTAOrder(getParent());
  return eraseOrder(order, getDim());
}
/// Returns the parent's CTAsPerCGA with the sliced dimension removed.
///
/// This is only defined when the parent has exactly one CTA along the sliced
/// dimension. Otherwise it calls llvm::report_fatal_error (see the explanation
/// below).
SmallVector<unsigned> SliceEncodingAttr::getCTAsPerCGA() const {
  auto parentCTAsPerCGA = ::getCTAsPerCGA(getParent());
  if (parentCTAsPerCGA[getDim()] == 1) {
    parentCTAsPerCGA.erase(parentCTAsPerCGA.begin() + getDim());
    return parentCTAsPerCGA;
  }
  /* For getCTAsPerCGA of a slice layout whose sliced dim is not 1, there are
   * two choices:
   * (1) Return CTAsPerCGA of its parent. This is not a perfect solution
   * because the rank of the returned CTAsPerCGA does not match the rank of
   * tensorShape.
   * (2) Get CTAsPerCGA of its parent and erase the sliced dim. This is not a
   * perfect solution because the product of the returned CTAsPerCGA might not
   * match numCTAs.
   * To avoid introducing inconsistencies to the shape and
   * layout system, the usage of directly getting CTAsPerCGA of a slice layout
   * in which the sliced dim is not 1 is banned. You should always consider
   * slice layout as a special case and use getCTAsPerCGA(layout.getParent())
   * in the branch where layout is an instance of SliceEncodingAttr. This is
   * inconvenient but safe.
   */
  llvm::report_fatal_error(
      "getCTAsPerCGA for SliceEncodingAttr is not well-defined");
}
/// Returns `shape` with a unit extent re-inserted at the sliced dimension, so
/// the result has one more dimension than `shape`.
template <class T>
SmallVector<T> SliceEncodingAttr::paddedShape(ArrayRef<T> shape) const {
  SmallVector<T> padded(shape.begin(), shape.end());
  padded.insert(padded.begin() + getDim(), T(1));
  return padded;
}
template SmallVector<unsigned>
SliceEncodingAttr::paddedShape<unsigned>(ArrayRef<unsigned> shape) const;
template SmallVector<int64_t>
SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;
/// Builds a CTALayoutAttr of the given rank from the CTA keys in `attrList`
/// (CTAsPerCGA, CTASplitNum, CTAOrder). Any other key is an error.
std::optional<CTALayoutAttr>
parseCTAAttrs(AsmParser &parser, NamedAttrList attrList, unsigned rank) {
  std::optional<SmallVector<unsigned>> CTAsPerCGA;
  std::optional<SmallVector<unsigned>> CTASplitNum;
  std::optional<SmallVector<unsigned>> CTAOrder;
  for (const NamedAttribute &attr : attrList) {
    StringRef key = attr.getName().strref();
    LogicalResult parsed = failure();
    if (key == "CTAsPerCGA") {
      parsed =
          parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA");
    } else if (key == "CTASplitNum") {
      parsed = parseIntArrayAttr(parser, attr, CTASplitNum.emplace(),
                                 "CTASplitNum");
    } else if (key == "CTAOrder") {
      parsed = parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder");
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ") << key;
      return {};
    }
    if (failed(parsed))
      return {};
  }
  return getCTALayoutOrError(parser, CTAsPerCGA, CTASplitNum, CTAOrder, rank);
}
/// Shared parser for swizzled shared-memory encodings of the form
/// `<{vec = ..., perPhase = ..., maxPhase = ..., order = [...], CTA*...}>`.
/// Keys other than the swizzling parameters are handed to parseCTAAttrs,
/// which rejects anything that is not a CTA layout key.
template <typename SpecificEncoding>
Attribute parseSwizzledEncoding(AsmParser &parser, Type type) {
  DictionaryAttr dict;
  if (failed(parser.parseLess()) || failed(parser.parseAttribute(dict)) ||
      failed(parser.parseGreater()))
    return {};

  unsigned vec = 0;
  unsigned perPhase = 0;
  unsigned maxPhase = 0;
  SmallVector<unsigned> order;
  NamedAttrList ctaAttrs;
  for (const NamedAttribute &attr : dict) {
    LogicalResult parsed = success();
    if (attr.getName() == "vec")
      parsed = parseUInt(parser, attr, vec, "vec");
    else if (attr.getName() == "perPhase")
      parsed = parseUInt(parser, attr, perPhase, "perPhase");
    else if (attr.getName() == "maxPhase")
      parsed = parseUInt(parser, attr, maxPhase, "maxPhase");
    else if (attr.getName() == "order")
      parsed = parseIntArrayAttr(parser, attr, order, "order");
    else
      ctaAttrs.push_back(attr);
    if (failed(parsed))
      return {};
  }

  auto ctaLayout = parseCTAAttrs(parser, ctaAttrs, order.size());
  if (!ctaLayout)
    return {};
  return parser.getChecked<SpecificEncoding>(
      parser.getContext(), vec, perPhase, maxPhase, order, *ctaLayout);
}
LogicalResult
SwizzledSharedEncodingAttr::verify(function_ref<InFlightDiagnostic()> emitError,
                                   unsigned vec, unsigned perPhase,
                                   unsigned maxPhase, ArrayRef<unsigned> order,
                                   CTALayoutAttr ctaLayout) {
  // The order must be a valid permutation of the CTA layout's rank.
  const auto ctaRank = ctaLayout.getRank();
  if (order.size() == ctaRank)
    return verifyLayoutOrder(emitError, order);
  return emitError() << "order size (" << order.size()
                     << ") must match CTALayout rank (" << ctaRank << ")";
}
Attribute SwizzledSharedEncodingAttr::parse(AsmParser &parser, Type type) {
  // Delegates to the generic swizzled-encoding parser.
  using Encoding = SwizzledSharedEncodingAttr;
  return parseSwizzledEncoding<Encoding>(parser, type);
}
void SwizzledSharedEncodingAttr::print(AsmPrinter &printer) const {
  // The CTA layout rank is taken from the length of `order`.
  auto order = getOrder();
  printer << "<{vec = " << getVec() << ", perPhase = " << getPerPhase()
          << ", maxPhase = " << getMaxPhase() << ", order = [" << order
          << "]";
  maybePrintCTALayout(getContext(), printer, getCTALayout(), order.size());
  printer << "}>";
}
/// Parses `<[interval:+padding, ...] {order = [...], CTA*...}>`.
/// Keys other than `order` in the trailing dictionary go to parseCTAAttrs.
Attribute PaddedSharedEncodingAttr::parse(AsmParser &parser, Type type) {
  if (failed(parser.parseLess()) || failed(parser.parseLSquare()))
    return {};

  SmallVector<unsigned, 4> intervals;
  SmallVector<unsigned, 4> paddings;
  auto parseOnePair = [&]() -> LogicalResult {
    unsigned intervalValue = 0;
    unsigned paddingValue = 0;
    if (failed(parser.parseInteger(intervalValue)) ||
        failed(parser.parseColon()) || failed(parser.parsePlus()) ||
        failed(parser.parseInteger(paddingValue)))
      return failure();
    intervals.push_back(intervalValue);
    paddings.push_back(paddingValue);
    return success();
  };
  if (failed(parser.parseCommaSeparatedList(parseOnePair)) ||
      failed(parser.parseRSquare()))
    return {};

  NamedAttrList attrList;
  if (failed(parser.parseOptionalAttrDict(attrList)) ||
      failed(parser.parseGreater()))
    return {};

  SmallVector<unsigned> order;
  NamedAttrList ctaAttrs;
  for (const NamedAttribute &attr : attrList) {
    if (attr.getName() == "order") {
      if (parseIntArrayAttr(parser, attr, order, "order").failed())
        return {};
      continue;
    }
    ctaAttrs.push_back(attr);
  }

  auto ctaLayout = parseCTAAttrs(parser, ctaAttrs, order.size());
  if (!ctaLayout)
    return {};
  return parser.getChecked<PaddedSharedEncodingAttr>(
      parser.getContext(), intervals, paddings, order, *ctaLayout);
}
void PaddedSharedEncodingAttr::print(AsmPrinter &printer) const {
  // Interval/padding pairs are written as `interval:+padding`, comma-separated.
  printer << "<[";
  bool first = true;
  for (auto [interval, padding] : llvm::zip(getIntervals(), getPaddings())) {
    if (!first)
      printer << ", ";
    first = false;
    printer << interval << ":+" << padding;
  }
  auto order = getOrder();
  printer << "] {order = [" << order << "]";
  maybePrintCTALayout(getContext(), printer, getCTALayout(), order.size());
  printer << "}>";
}
/// Validates a padded shared encoding. Checks run in a fixed order:
/// 1. The interval and padding lists have the same non-zero length.
/// 2. Every interval and every padding is a power of two.
/// 3. Intervals contain no duplicates.
/// 4. `order` is non-empty, matches the CTA layout rank, and is a valid order.
LogicalResult PaddedSharedEncodingAttr::verify(
    function_ref<InFlightDiagnostic()> emitError, ArrayRef<unsigned> intervals,
    ArrayRef<unsigned> paddings, ArrayRef<unsigned> order,
    CTALayoutAttr ctaLayout) {
  if (intervals.size() != paddings.size())
    return emitError() << "intervals size (" << intervals.size()
                       << ") must match paddings size (" << paddings.size()
                       << ")";
  if (intervals.empty())
    return emitError() << "must have at least one interval-padding pair";

  auto allPowersOfTwo = [](ArrayRef<unsigned> values) {
    return llvm::all_of(values, llvm::isPowerOf2_32);
  };
  if (!allPowersOfTwo(intervals))
    return emitError() << "interval values must all be power of two";
  if (!allPowersOfTwo(paddings))
    return emitError() << "padding values must all be power of two";

  llvm::SmallSet<unsigned, 4> uniqueIntervals;
  for (unsigned interval : intervals)
    uniqueIntervals.insert(interval);
  if (uniqueIntervals.size() != intervals.size())
    return emitError() << "interval values cannot have duplicates";

  if (order.empty())
    return emitError() << "order cannot be empty";
  const auto ctaRank = ctaLayout.getRank();
  if (order.size() != ctaRank)
    return emitError() << "order size (" << order.size()
                       << ") must match CTALayout rank (" << ctaRank << ")";
  return verifyLayoutOrder(emitError, order);
}
/// Convenience builder that takes (interval, padding) pairs and splits them
/// into the two parallel lists stored by the attribute.
PaddedSharedEncodingAttr PaddedSharedEncodingAttr::get(
    MLIRContext *context, ArrayRef<std::pair<unsigned, unsigned>> intervalPads,
    ArrayRef<unsigned> order, CTALayoutAttr ctaLayout) {
  const size_t numPairs = intervalPads.size();
  SmallVector<unsigned> intervals;
  SmallVector<unsigned> paddings;
  intervals.reserve(numPairs);
  paddings.reserve(numPairs);
  for (const std::pair<unsigned, unsigned> &pair : intervalPads) {
    intervals.push_back(pair.first);
    paddings.push_back(pair.second);
  }
  return get(context, intervals, paddings, order, ctaLayout);
}
/// Returns the element count of `shape` plus all inserted padding.
///
/// For each (interval, padding) pair, `padding` elements are counted once per
/// full interval. One `padding` is subtracted when the size is an exact
/// multiple of `interval` (presumably because no padding follows the final
/// interval; confirm). Shifts assume power-of-two values, which verify()
/// enforces.
int64_t PaddedSharedEncodingAttr::getPaddedSize(ArrayRef<int64_t> shape) const {
  const int64_t elementCount = product(shape);
  int64_t extra = 0;
  for (auto [interval, padding] :
       llvm::zip_equal(getIntervals(), getPaddings())) {
    const int64_t fullIntervals = elementCount >> llvm::Log2_32(interval);
    extra += fullIntervals << llvm::Log2_32(padding);
    if (elementCount % interval == 0)
      extra -= padding;
  }
  return elementCount + extra;
}
/// Parses `<{swizzlingByteWidth = ..., transposed = ..., elementBitWidth = ...,
/// fp4Padded = ..., CTA*...}>` for an NVMMA shared encoding.
///
/// Unknown keys are rejected. The CTA layout is built with rank 2.
Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) {
  if (parser.parseLess().failed())
    return {};
  DictionaryAttr dict;
  if (parser.parseAttribute(dict).failed())
    return {};
  if (parser.parseGreater().failed())
    return {};

  // Initialized so that an omitted key produces a deterministic value instead
  // of reading an uninitialized variable.
  unsigned swizzlingByteWidth = 0;
  bool transposed = false;
  bool fp4Padded = false;
  unsigned elementBitWidth = 0;
  std::optional<SmallVector<unsigned>> CTAsPerCGA;
  std::optional<SmallVector<unsigned>> CTASplitNum;
  std::optional<SmallVector<unsigned>> CTAOrder;

  for (const NamedAttribute &attr : dict) {
    if (attr.getName() == "swizzlingByteWidth") {
      if (parseUInt(parser, attr, swizzlingByteWidth, "swizzlingByteWidth")
              .failed())
        return {};
    } else if (attr.getName() == "transposed") {
      if (parseBool(parser, attr, transposed, "transposed").failed())
        return {};
    } else if (attr.getName() == "elementBitWidth") {
      if (parseUInt(parser, attr, elementBitWidth, "elementBitWidth").failed())
        return {};
    } else if (attr.getName() == "fp4Padded") {
      if (parseBool(parser, attr, fp4Padded, "fp4Padded").failed())
        return {};
    } else if (attr.getName() == "CTAsPerCGA") {
      if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA")
              .failed())
        return {};
    } else if (attr.getName() == "CTASplitNum") {
      if (parseIntArrayAttr(parser, attr, CTASplitNum.emplace(), "CTASplitNum")
              .failed())
        return {};
    } else if (attr.getName() == "CTAOrder") {
      if (parseIntArrayAttr(parser, attr, CTAOrder.emplace(), "CTAOrder")
              .failed())
        return {};
    } else {
      parser.emitError(parser.getNameLoc(), "unexpected key: ")
          << attr.getName().strref();
      return {};
    }
  }

  std::optional<CTALayoutAttr> CTALayout = getCTALayoutOrError(
      parser, CTAsPerCGA, CTASplitNum, CTAOrder, 2);
  if (!CTALayout.has_value())
    return {};
  return parser.getChecked<NVMMASharedEncodingAttr>(
      parser.getContext(), swizzlingByteWidth, transposed, elementBitWidth,
      fp4Padded, *CTALayout);
}
void NVMMASharedEncodingAttr::print(AsmPrinter &printer) const {
  printer << "<{swizzlingByteWidth = " << getSwizzlingByteWidth();
  printer << ", transposed = " << getTransposed();
  printer << ", elementBitWidth = " << getElementBitWidth();
  // fp4Padded is printed only when it is true.
  if (getFp4Padded())
    printer << ", fp4Padded = true";
  // The CTA layout is printed for rank 2.
  maybePrintCTALayout(getContext(), printer, getCTALayout(), 2);
  printer << "}>";
}
int NVMMASharedEncodingAttr::getVec() const {
  // Unswizzled layouts report 1; otherwise the vector is 128 bits of elements.
  const bool unswizzled = getSwizzlingByteWidth() == 0;
  return unswizzled ? 1 : 128 / getElementBitWidth();
}
int NVMMASharedEncodingAttr::getPerPhase() const {
  // Unswizzled layouts report 1; otherwise 128 divided by the swizzle width.
  const unsigned swizzleBytes = getSwizzlingByteWidth();
  return swizzleBytes == 0 ? 1 : 128 / swizzleBytes;
}
int NVMMASharedEncodingAttr::getMaxPhase() const {
  // Unswizzled layouts report 1; otherwise the swizzle width in 16-byte units.
  const unsigned swizzleBytes = getSwizzlingByteWidth();
  return swizzleBytes == 0 ? 1 : swizzleBytes / 16;
}
int32_t NVMMASharedEncodingAttr::getAlignment() const {
return 128 * getMaxPhase();
}
Attribute AMDRotatingSharedEncodingAttr::parse(AsmParser &parser, Type type) {
  // Uses the same textual format as the other swizzled shared encodings.
  using Encoding = AMDRotatingSharedEncodingAttr;
  return parseSwizzledEncoding<Encoding>(parser, type);
}
void AMDRotatingSharedEncodingAttr::print(AsmPrinter &printer) const {
  // Same layout as the swizzled encoding printer; CTA rank comes from `order`.
  auto order = getOrder();
  printer << "<{vec = " << getVec() << ", perPhase = " << getPerPhase()
          << ", maxPhase = " << getMaxPhase() << ", order = [" << order
          << "]";
  maybePrintCTALayout(getContext(), printer, getCTALayout(), order.size());
  printer << "}>";
}
bool AMDMfmaEncodingAttr::hasUnitTilesPerWarp() const {
  // True when every tilesPerWarp entry is 1 (also true for an empty list).
  return llvm::all_of(getTilesPerWarp(), [](int tiles) { return tiles == 1; });
}
/// Returns the per-instruction tile shape of operand `opIdx`:
/// {mDim, K} for operand 0 and {K, nDim} for operand 1.
/// K is kWidth times the number of K groups, where the number of K groups is
/// 64 (the warp size used here) divided by min(mDim, nDim).
SmallVector<int64_t>
AMDMfmaEncodingAttr::getInstrShapeForOperand(int kWidth, int opIdx) const {
  const unsigned m = getMDim();
  const unsigned n = getNDim();
  assert(((m == n) && (m == 32 || m == 16 || m == 4)) ||
         (m == 64 && n == 4) || (m == 4 && n == 64));
  constexpr int warpSize = 64;
  const int kGroups = warpSize / std::min(m, n);
  const int64_t kDim = kWidth * kGroups;
  if (opIdx != 0) {
    assert(opIdx == 1);
    return {kDim, n};
  }
  return {m, kDim};
}
SmallVector<unsigned> AMDMfmaEncodingAttr::getRepOrder() const {
  // Passes `true` as the second argument of getMatrixOrder.
  auto rank = getRank();
  return getMatrixOrder(rank, true);
}
SmallVector<unsigned>
AMDMfmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
  // Passes `true` as the last argument of getOrderForDotOperand.
  auto rank = getRank();
  return getOrderForDotOperand(opIdx, rank, true);
}
/// Returns {batch, dim0, dim1} repetition counts of MFMA instruction tiles for
/// an operand of shape `operandShape` (rank 2 or 3). Every count is at least 1.
/// The non-K dimension accounts for both warpsPerCTA and tilesPerWarp.
SmallVector<int64_t>
AMDMfmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape,
                                      int kWidth, int opIdx) const {
  auto tileShape = getInstrShapeForOperand(kWidth, opIdx);
  const auto rank = operandShape.size();
  auto warps = getWarpsPerCTA();
  auto tiles = getTilesPerWarp();
  int batchReps =
      rank == 3 ? std::max<int64_t>(1, operandShape[0] / warps[0]) : 1;

  // Repetitions along a non-K dim. The tile is first replicated by
  // tilesPerWarp and warpsPerCTA, and the result is scaled back up by
  // tilesPerWarp.
  auto nonKReps = [&](size_t dim, int64_t tile) -> int64_t {
    return std::max<int64_t>(1, operandShape[dim] /
                                    (tile * tiles[dim] * warps[dim])) *
           tiles[dim];
  };

  if (opIdx != 0) {
    assert(opIdx == 1);
    return {batchReps,
            std::max<int64_t>(1, operandShape[rank - 2] / tileShape[0]),
            nonKReps(rank - 1, tileShape[1])};
  }
  return {batchReps, nonKReps(rank - 2, tileShape[0]),
          std::max<int64_t>(1, operandShape[rank - 1] / tileShape[1])};
}
// Chooses a swizzled shared-memory layout for MFMA operand `operandIdx`.
// An unswizzled layout (vec = perPhase = maxPhase = 1) is returned in two
// cases:
//   - operandIdx >= 2;
//   - the fastest-varying shared dim (sharedOrder[0]) is not K, unless this is
//     version 4 (named isGFX950 below) with 8- or 16-bit elements.
// Otherwise perPhase and maxPhase are derived from the bank parameters below,
// and maxPhase is forced to 4 when mDim == 4.
SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand(
CTALayoutAttr ctaLayout, int operandIdx, ArrayRef<int64_t> operandShape,
ArrayRef<unsigned> sharedOrder, unsigned vectorSize, unsigned elemBitWidth,
bool needTrans) const {
// K is dim 1 for operand 0 and dim 0 for operand 1 (before needTrans).
int kDimIndex = operandIdx == 0 ? 1 : 0;
if (operandIdx >= 2) {
return SwizzledSharedEncodingAttr::get(getContext(), 1, 1, 1, sharedOrder,
ctaLayout);
}
// A transposed operand swaps which dim holds K.
if (needTrans)
kDimIndex = 1 - kDimIndex;
bool isKContig = sharedOrder[0] == kDimIndex;
bool isGFX950 = getVersion() == 4;
bool swizzleNonKContig =
isGFX950 && (elemBitWidth == 8 || elemBitWidth == 16);
if (!isKContig && !swizzleNonKContig) {
return SwizzledSharedEncodingAttr::get(getContext(), 1, 1, 1, sharedOrder,
ctaLayout);
}
// Version 4 uses 64 banks; other versions use 32. Each bank is 32 bits wide.
const unsigned numBanks = isGFX950 ? 64 : 32;
const unsigned bankBitWidth = 32;
const unsigned simdWidth = 16;
int innerDimLength = operandShape[sharedOrder[0]];
int elemsPerOneBanksRow = (numBanks * bankBitWidth) / elemBitWidth;
int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength);
// NOTE(review): both divisions inside std::min use unsigned arithmetic
// (simdWidth and vectorSize are unsigned), hence the `1u` lower bound.
int maxPhase =
std::max(std::min(simdWidth / perPhase, innerDimLength / vectorSize), 1u);
if (getMDim() == 4)
maxPhase = 4;
return SwizzledSharedEncodingAttr::get(getContext(), vectorSize, perPhase,
maxPhase, sharedOrder, ctaLayout);
}
SmallVector<unsigned> AMDWmmaEncodingAttr::getRepOrder() const {
  // Passes `true` as the second argument of getMatrixOrder.
  auto rank = getRank();
  return getMatrixOrder(rank, true);
}
SmallVector<unsigned>
AMDWmmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
  // Passes `true` as the last argument of getOrderForDotOperand.
  auto rank = getRank();
  return getOrderForDotOperand(opIdx, rank, true);
}
SmallVector<int64_t>
AMDWmmaEncodingAttr::getElemsPerInstrForOperands(int kDim, int opIdx) const {
  // Operand 0 tiles are 16 x kDim; any other operand tile is kDim x 16.
  constexpr int64_t kNonK = 16;
  if (opIdx != 0)
    return {kDim, kNonK};
  return {kNonK, kDim};
}
/// Returns {batch, dim0, dim1} repetition counts of WMMA instruction tiles for
/// an operand of rank 2 or 3. Every count is at least 1. Only the non-K
/// dimension is divided by warpsPerCTA. `elemType` and `kWidth` are not used.
SmallVector<int64_t>
AMDWmmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> operandShape,
                                      Type elemType, int kWidth, int kDim,
                                      int opIdx) const {
  auto tile = getElemsPerInstrForOperands(kDim, opIdx);
  assert(tile.size() == 2);
  auto warps = getWarpsPerCTA();
  const auto rank = operandShape.size();
  assert(rank == 2 || rank == 3);
  int batchReps =
      rank == 3 ? std::max<int64_t>(1, operandShape[0] / warps[0]) : 1;
  const auto rowDim = rank - 2;
  const auto colDim = rank - 1;
  if (opIdx != 0) {
    assert(opIdx == 1);
    return {batchReps,
            std::max<int64_t>(1, operandShape[rowDim] / tile[0]),
            std::max<int64_t>(1, operandShape[colDim] /
                                     (tile[1] * warps[colDim]))};
  }
  return {batchReps,
          std::max<int64_t>(1, operandShape[rowDim] /
                                   (tile[0] * warps[rowDim])),
          std::max<int64_t>(1, operandShape[colDim] / tile[1])};
}
SmallVector<unsigned> AMDWmmaEncodingAttr::getMNKDimPerInstr() {
  // Returns 16 for each of M, N and K.
  constexpr unsigned kInstrDim = 16;
  return {kInstrDim, kInstrDim, kInstrDim};
}
/// Chooses a swizzled shared-memory layout for WMMA operand `operandIdx`.
/// If the fastest-varying shared dim is not K, the layout is unswizzled
/// (vec = perPhase = maxPhase = 1). `needTrans` is not consulted.
SwizzledSharedEncodingAttr AMDWmmaEncodingAttr::composeSharedLayoutForOperand(
    CTALayoutAttr ctaLayout, int operandIdx, ArrayRef<int64_t> operandShape,
    ArrayRef<unsigned> sharedOrder, unsigned kWidth, unsigned elemBitWidth,
    bool needTrans) const {
  // K is dim 1 for operand 0 and dim 0 otherwise.
  const int kDimIndex = operandIdx == 0 ? 1 : 0;
  const unsigned fastestDim = sharedOrder[0];
  if (fastestDim != kDimIndex) {
    return SwizzledSharedEncodingAttr::get(getContext(), 1, 1, 1, sharedOrder,
                                           ctaLayout);
  }
  // A single vector access is capped at 128 bits.
  const int vectorSize = std::min(kWidth * elemBitWidth, 128u) / elemBitWidth;
  const int numBanks = 32;
  const int bankBitWidth = 32;
  const int innerDimLength = operandShape[fastestDim];
  const int elemsPerBankRow = (numBanks * bankBitWidth) / elemBitWidth;
  const int perPhase = std::max(1, elemsPerBankRow / innerDimLength);
  const int mDim = getMNKDimPerInstr()[0];
  const int maxPhase =
      std::max(std::min(mDim / perPhase, innerDimLength / vectorSize), 1);
  return SwizzledSharedEncodingAttr::get(getContext(), vectorSize, perPhase,
                                         maxPhase, sharedOrder, ctaLayout);
}
bool NvidiaMmaEncodingAttr::isVolta() const {
  const auto major = getVersionMajor();
  return major == 1;
}
bool NvidiaMmaEncodingAttr::isTuring() const {
  // Version 2.1. Note that isAmpere() also returns true for it.
  const auto major = getVersionMajor();
  const auto minor = getVersionMinor();
  return major == 2 && minor == 1;
}
bool NvidiaMmaEncodingAttr::isAmpere() const {
  // Any major version 2, regardless of the minor version.
  const auto major = getVersionMajor();
  return major == 2;
}
bool NvidiaMmaEncodingAttr::isHopper() const {
  const auto major = getVersionMajor();
  return major == 3;
}
SmallVector<unsigned> NvidiaMmaEncodingAttr::getRepOrder() const {
  // Passes `true` as the second argument of getMatrixOrder.
  auto rank = getRank();
  return getMatrixOrder(rank, true);
}
SmallVector<unsigned>
NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
  // Passes `true` as the last argument of getOrderForDotOperand.
  auto rank = getRank();
  return getOrderForDotOperand(opIdx, rank, true);
}
/// Returns per-dimension repetition counts (with a leading batch entry) of MMA
/// tiles for an operand of shape `shape`. Every count is at least 1.
/// The warp count along K is treated as 1. The K extent of a tile is 4*64
/// bits, or 4*256 bits for 64-bit elements on Ampere (major version 2).
SmallVector<int64_t>
NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
                                        int kWidth, int opIdx) const {
  assert(kWidth >= std::max(32 / bitwidth, 1) &&
         "kWidth must be >= max(32 / bitwidth, 1) for this function to be "
         "well-defined");
  const auto rank = shape.size();
  const bool isBatched = rank == 3;

  auto warps = to_vector(getWarpsPerCTA());
  warps[opIdx == 0 ? rank - 1 : rank - 2] = 1;

  const int kTileBits = (isAmpere() && bitwidth == 64) ? (4 * 256) : (4 * 64);
  const int kTileElems = kTileBits / bitwidth;
  SmallVector<int> tile;
  if (isBatched)
    tile.push_back(1);
  if (opIdx == 0)
    tile.append({16, kTileElems});
  else
    tile.append({kTileElems, 8});

  // Non-batched shapes still get a leading batch repetition of 1.
  SmallVector<int64_t> reps;
  if (!isBatched)
    reps.push_back(1);
  for (auto [extent, tileExtent, warpCount] : llvm::zip(shape, tile, warps))
    reps.push_back(std::max<int64_t>(1, extent / (tileExtent * warpCount)));
  return reps;
}
SmallVector<unsigned> DotOperandEncodingAttr::getRepOrder() const {
  // MMA parents supply an operand-specific order; blocked parents reuse their
  // own order. Any other parent is a fatal error.
  auto parent = getParent();
  if (auto mma = mlir::dyn_cast<MmaEncodingTrait>(parent))
    return mma.getRepOrderForOperand(getOpIdx());
  if (auto blocked = mlir::dyn_cast<BlockedEncodingAttr>(parent))
    return to_vector(blocked.getOrder());
  llvm::report_fatal_error(
      "getRepOrder not implemented for DotOperandEncodingAttr");
  return {};
}
SmallVector<unsigned> DotOperandEncodingAttr::getCTAsPerCGA() const {
  // Forwarded unchanged from the parent layout.
  auto parent = getParent();
  return ::getCTAsPerCGA(parent);
}
SmallVector<unsigned> DotOperandEncodingAttr::getCTAOrder() const {
  // Forwarded unchanged from the parent layout.
  auto parent = getParent();
  return ::getCTAOrder(parent);
}
/// Returns the parent's CTA split counts with the operand's K dimension set to
/// 1. K is the last dim for operand 0 and the second-to-last for operand 1.
SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
  SmallVector<unsigned> res = ::getCTASplitNum(getParent());
  auto rank = res.size();
  // Parenthesized so the message applies to the whole rank condition; the
  // previous `a || b && "msg"` parsed as `a || (b && "msg")`.
  assert((rank == 2 || rank == 3) && "Invalid dotLayout");
  auto kDim = getOpIdx() == 0 ? rank - 1 : rank - 2;
  res[kDim] = 1;
  return res;
}
/// Verifies a dot-operand encoding against its parent layout.
///
/// - opIdx must be 0 or 1, and parent must be non-null.
/// - NVIDIA MMA parent: kWidth is required (non-zero) for Ampere/Hopper and
///   must be 0 otherwise. Hopper parents additionally require opIdx == 0.
/// - AMD WMMA parent: kWidth must be 8/16 for version 1 and 4/8/16 for
///   version 2.
/// - AMD MFMA parent: kWidth must be non-zero.
/// - Blocked parent: kWidth must be 0.
/// - Any other parent is rejected.
LogicalResult DotOperandEncodingAttr::verify(
    ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
    unsigned opIdx, Attribute parent, unsigned kWidth) {
  if (opIdx != 0 && opIdx != 1) {
    return emitError() << "ttg.dot_op opIdx parameter can be 0 or 1, got: "
                       << opIdx;
  }
  if (!parent) {
    return emitError() << "ttg.dot_op parent parameter cannot be null";
  }
  if (auto parentAttr = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
    const bool needsKWidth = parentAttr.isAmpere() || parentAttr.isHopper();
    if (kWidth != 0 && !needsKWidth)
      return emitError() << "ttg.dot_op kWidth parameter can only be "
                            "non-zero for Ampere or Hopper MMA parent";
    if (kWidth == 0 && needsKWidth)
      return emitError() << "ttg.dot_op kWidth parameter is mandatory for "
                            "Ampere or Hopper MMA parent";
    if (opIdx != 0 && parentAttr.isHopper())
      return emitError()
             << "ttg.dot_op opIdx parameter must be 0 for "
                "Hopper MMA parent, since Hopper WGMMA only allows first "
                "operand to be in registers";
    return success();
  }
  if (auto parentAttr = mlir::dyn_cast<AMDWmmaEncodingAttr>(parent)) {
    // Explicit per-version predicates instead of a bare `&&`/`||` mix.
    // Versions other than 1 and 2 are not constrained here.
    const auto version = parentAttr.getVersion();
    const bool badV1KWidth = version == 1 && kWidth != 8 && kWidth != 16;
    const bool badV2KWidth =
        version == 2 && kWidth != 4 && kWidth != 8 && kWidth != 16;
    if (badV1KWidth || badV2KWidth)
      return emitError() << "ttg.dot_op kWidth parameter must be 8/16 for "
                            "gfx11 and 4/8/16 for gfx12 (including packed "
                            "cases for `scaled_dot`)";
    return success();
  }
  if (mlir::isa<AMDMfmaEncodingAttr>(parent)) {
    if (kWidth == 0)
      return emitError() << "ttg.dot_op kWidth parameter is mandatory for "
                            "MFMA parent";
    return success();
  }
  if (mlir::isa<BlockedEncodingAttr>(parent)) {
    if (kWidth != 0)
      return emitError() << "ttg.dot_op kWidth parameter is not supported "
                            "when the parent is a blocked layout";
    return success();
  }
  return emitError() << "ttg.dot_op unexpected parent layout: " << parent;
}
/// Provides short textual aliases (e.g. `#mma`, `#blocked`) for TritonGPU
/// attributes when IR is printed.
class TritonGPUOpAsmInterface : public OpAsmDialectInterface {
public:
  using OpAsmDialectInterface::OpAsmDialectInterface;

  AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
    // Encoding attributes.
    if (mlir::isa<MmaEncodingTrait>(attr)) {
      os << "mma";
      return AliasResult::FinalAlias;
    } else if (mlir::isa<SharedEncodingTrait>(attr)) {
      os << "shared";
      return AliasResult::FinalAlias;
    } else if (mlir::isa<BlockedEncodingAttr>(attr)) {
      os << "blocked";
      return AliasResult::FinalAlias;
    } else if (mlir::isa<LinearEncodingAttr>(attr)) {
      os << "linear";
      return AliasResult::FinalAlias;
    }
    // NOTE(review): the slice-encoding alias was a commented-out branch
    // (a dangling `*/` was left behind) and stays disabled here:
    //   else if (isa<SliceEncodingAttr>(attr)) {
    //     os << "slice";
    //     return AliasResult::FinalAlias;
    //   }

    // Memory space attributes.
    if (mlir::isa<SharedMemorySpaceAttr>(attr)) {
      os << "smem";
      return AliasResult::FinalAlias;
    }
    return OpAsmDialectInterface::getAlias(attr, os);
  }
};
struct TritonGPUInferLayoutInterface
: public triton::DialectInferLayoutInterface {
using DialectInferLayoutInterface::DialectInferLayoutInterface;
LogicalResult
inferReduceOpEncoding(Attribute operandEncoding, unsigned axis,
                      Attribute &resultEncoding,
                      std::optional<Location> loc) const override {
  // Reducing along `axis` yields the operand layout sliced at that axis.
  auto parent = cast<DistributedEncodingTrait>(operandEncoding);
  MLIRContext *ctx = getDialect()->getContext();
  resultEncoding = SliceEncodingAttr::get(ctx, axis, parent);
  return success();
}
/// Infers the encoding produced by permuting dims by `order`.
/// - Identity permutation: the operand encoding is returned unchanged.
/// - Swizzled/NVMMA/padded shared and blocked encodings: their fields are
///   permuted directly. NVMMA supports only the 2-D {1, 0} permutation.
/// - Any other encoding: falls back to transposing its linear layout.
LogicalResult
inferTransOpEncoding(Attribute operandEncoding, ArrayRef<int64_t> shape,
                     ArrayRef<int32_t> order, Attribute &resultEncoding,
                     std::optional<Location> loc) const override {
  if (isIota(order)) {
    resultEncoding = operandEncoding;
    return success();
  }
  if (shape.size() != order.size())
    return emitOptionalError(loc, "shape and order rank do not match: ",
                             shape.size(), " vs ", order.size());

  auto verifyRank = [&](unsigned encRank) -> LogicalResult {
    if (encRank == order.size())
      return success();
    return emitOptionalError(loc, "rank of encoding does not match order: ",
                             encRank, " vs ", order.size());
  };

  auto *ctx = getDialect()->getContext();
  auto inverse = inversePermutation(order);
  SmallVector<unsigned> inverseOrder(inverse.begin(), inverse.end());

  if (auto swizzled = dyn_cast<SwizzledSharedEncodingAttr>(operandEncoding)) {
    if (failed(verifyRank(swizzled.getRank())))
      return failure();
    CTALayoutAttr ctaLayout =
        permuteCTALayout(ctx, swizzled.getCTALayout(), order);
    resultEncoding = SwizzledSharedEncodingAttr::get(
        ctx, swizzled.getVec(), swizzled.getPerPhase(), swizzled.getMaxPhase(),
        applyPermutation(inverseOrder, swizzled.getOrder()), ctaLayout);
    return success();
  }

  if (auto nvmma = dyn_cast<NVMMASharedEncodingAttr>(operandEncoding)) {
    if (failed(verifyRank(nvmma.getRank())))
      return failure();
    if (order != ArrayRef<int32_t>({1, 0}))
      return emitOptionalError(
          loc, "NVMMSharedEncoding can only be transposed in 2D");
    CTALayoutAttr ctaLayout =
        permuteCTALayout(ctx, nvmma.getCTALayout(), order);
    // A 2-D transpose flips the `transposed` flag.
    resultEncoding = NVMMASharedEncodingAttr::get(
        ctx, nvmma.getSwizzlingByteWidth(), !nvmma.getTransposed(),
        nvmma.getElementBitWidth(), nvmma.getFp4Padded(), ctaLayout);
    return success();
  }

  if (auto blocked = dyn_cast<BlockedEncodingAttr>(operandEncoding)) {
    if (failed(verifyRank(blocked.getRank())))
      return failure();
    CTALayoutAttr ctaLayout =
        permuteCTALayout(ctx, blocked.getCTALayout(), order);
    resultEncoding = BlockedEncodingAttr::get(
        ctx, applyPermutation(blocked.getSizePerThread(), order),
        applyPermutation(blocked.getThreadsPerWarp(), order),
        applyPermutation(blocked.getWarpsPerCTA(), order),
        applyPermutation(inverseOrder, blocked.getOrder()), ctaLayout);
    return success();
  }

  if (auto padded = dyn_cast<PaddedSharedEncodingAttr>(operandEncoding)) {
    if (failed(verifyRank(padded.getRank())))
      return failure();
    CTALayoutAttr ctaLayout =
        permuteCTALayout(ctx, padded.getCTALayout(), order);
    resultEncoding = PaddedSharedEncodingAttr::get(
        ctx, padded.getIntervals(), padded.getPaddings(),
        applyPermutation(inverseOrder, padded.getOrder()), ctaLayout);
    return success();
  }

  // Generic fallback: transpose the linear layout representation.
  auto linear = toLinearLayout(shape, operandEncoding);
  auto transposed = transposeLinearLayout(linear, order);
  resultEncoding = LinearEncodingAttr::get(ctx, std::move(transposed));
  return success();
}
LogicalResult
inferExpandDimsOpEncoding(Attribute operandEncoding, unsigned axis,
                          Attribute &resultEncoding,
                          std::optional<Location> location) const override {
  // The operand must be a SliceEncodingAttr sliced along exactly `axis`; the
  // inferred result encoding is that slice's parent encoding.
  SliceEncodingAttr slice = mlir::dyn_cast<SliceEncodingAttr>(operandEncoding);
  if (!slice) {
    return emitOptionalError(
        location, "ExpandDimsOp operand encoding must be SliceEncodingAttr");
  }
  if (axis != slice.getDim()) {
    return emitOptionalError(
        location, "Incompatible slice dimension for ExpandDimsOp operand");
  }
  resultEncoding = slice.getParent();
  return success();
}
LogicalResult
inferDotOpEncoding(Attribute operandEncoding, unsigned opIdx,
                   Attribute retEncoding,
                   std::optional<Location> location) const override {
  // Checks that a dot operand's encoding is legal for the dot's result
  // encoding `retEncoding`.
  auto mmaRetEncoding = mlir::dyn_cast<NvidiaMmaEncodingAttr>(retEncoding);
  const bool isHopperMma = mmaRetEncoding && mmaRetEncoding.isHopper();
  if (isHopperMma) {
    // Hopper MMA (v3) accepts an NVMMA shared encoding for either operand, or,
    // for operand A only, a dot-operand encoding (opIdx 0) whose parent is an
    // NvidiaMmaEncodingAttr.
    auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(operandEncoding);
    const bool isAcceptable =
        mlir::isa<NVMMASharedEncodingAttr>(operandEncoding) ||
        (opIdx == 0 && dotOpEnc && dotOpEnc.getOpIdx() == 0 &&
         mlir::isa<NvidiaMmaEncodingAttr>(dotOpEnc.getParent()));
    if (!isAcceptable) {
      return emitOptionalError(
          location, "unexpected operand layout for NvidiaMmaEncodingAttr v3");
    }
    return success();
  }

  // Every other result encoding requires a dot-operand encoding whose operand
  // index and parent match this operand and the dot result.
  auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(operandEncoding);
  if (!dotOpEnc) {
    return emitOptionalError(
        location, "Dot's a/b's encoding should be of DotOperandEncodingAttr");
  }
  if (opIdx != dotOpEnc.getOpIdx())
    return emitOptionalError(location, "Wrong opIdx");
  if (retEncoding != dotOpEnc.getParent())
    return emitOptionalError(location, "Incompatible parent encoding");
  return success();
}
LogicalResult
verifyDotOpEncodingCompatibility(Operation *op, Attribute operandEncodingA,
                                 Attribute operandEncodingB) const override {
  // Verifies that the A and B operand encodings of a dot are compatible.
  // Returns success when neither is a DotOperandEncodingAttr, or when A is a
  // dot operand of a Hopper MMA (where B need not be a dot operand).
  // Otherwise both must be dot operands with the same kWidth.
  auto aEncoding =
      mlir::dyn_cast_or_null<triton::gpu::DotOperandEncodingAttr>(
          operandEncodingA);
  auto bEncoding =
      mlir::dyn_cast_or_null<triton::gpu::DotOperandEncodingAttr>(
          operandEncodingB);
  if (!aEncoding && !bEncoding)
    return mlir::success();
  // Only inspect A's parent when A actually is a dot operand. Previously a
  // null aEncoding (with a non-null bEncoding) was dereferenced here.
  if (aEncoding) {
    auto mmaAEncoding =
        mlir::dyn_cast_or_null<NvidiaMmaEncodingAttr>(aEncoding.getParent());
    if (mmaAEncoding && mmaAEncoding.isHopper())
      return success();
  }
  if (!aEncoding || !bEncoding)
    return op->emitError("mismatching encoding between A and B operands");
  if (aEncoding.getKWidth() != bEncoding.getKWidth())
    return op->emitError("mismatching kWidth between A and B operands");
  return success();
}
// Legacy reshape-encoding inference that works directly on BlockedEncodingAttr
// fields. On success sets `dstEnc` to a BlockedEncodingAttr for `dstShape`.
// Every rejection returns plain failure() without emitting a diagnostic, so
// callers can fall back to another inference strategy.
LogicalResult inferReshapeOpLegacyEncoding(ArrayRef<int64_t> srcShape,
Attribute srcEnc,
ArrayRef<int64_t> dstShape,
Attribute &dstEnc) const {
// Only blocked source encodings are handled here.
auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
if (!src) {
return failure();
}
// Reshaping to the same shape keeps the encoding unchanged.
if (srcShape == dstShape) {
dstEnc = srcEnc;
return success();
}
// If the source is the default blocked encoding for srcShape (with the same
// warp / thread / CTA counts), the result is the default blocked encoding for
// dstShape.
auto context = srcEnc.getContext();
int32_t numWarps = product(src.getWarpsPerCTA());
int32_t threadsPerWarp = product(src.getThreadsPerWarp());
int32_t numCTAs = product(src.getCTALayout().getCTAsPerCGA());
if (srcEnc == getDefaultBlockedEncoding(context, srcShape, numWarps,
threadsPerWarp, numCTAs)) {
dstEnc = getDefaultBlockedEncoding(context, dstShape, numWarps,
threadsPerWarp, numCTAs);
return success();
}
// Layouts spanning more than one CTA are rejected.
if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) ||
!all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) {
return failure();
}
// Each source dim must be divisible by the subblock size along that dim,
// unless the dim is smaller than the subblock. `name` is currently unused.
auto checkDivisibility = [&](StringRef name, ArrayRef<unsigned> subblock) {
for (int dim = 0; dim < srcShape.size(); dim++) {
if (srcShape[dim] >= subblock[dim] &&
srcShape[dim] % subblock[dim] != 0) {
return failure();
}
}
return success();
};
if (!succeeded(
checkDivisibility("sizePerThread", src.getSizePerThread())) ||
!succeeded(
checkDivisibility("threadsPerWarp", src.getThreadsPerWarp())) ||
!succeeded(checkDivisibility("warpsPerCTA", src.getWarpsPerCTA()))) {
return failure();
}
// NOTE(review): getReshapeDecomposition presumably pairs groups of srcShape
// dims with the groups of dstShape dims they reshape into — confirm.
SmallVector<std::pair<SmallVector<int64_t>, SmallVector<int64_t>>> decomp =
getReshapeDecomposition(srcShape, dstShape);
// For each group, the positions of its source dims in src's order (via the
// inverse permutation), reversed, must be consecutive; otherwise give up.
auto srcInvOrder = inversePermutation(src.getOrder());
for (const auto &[srcDims, dstDims] : decomp) {
if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) {
return failure();
}
}
for (const auto &[srcDims, dstDims] : decomp) {
// Per-group remaining extents. They are shared by reference across the
// three checkSubblock calls below, so sizePerThread, threadsPerWarp and
// warpsPerCTA are consumed in that order.
auto shapeRemaining = gather(srcShape, srcDims);
// Walk the group's dims from last to first. A dim with a non-trivial
// subblock is only allowed once every later dim of the group has been
// reduced to extent 1. Its remaining extent is divided by the subblock
// (or set to 0 when smaller), and only the group's first dim may be left
// at 0.
auto checkSubblock = [&, srcDims = srcDims](ArrayRef<unsigned> subblock) {
for (int i = srcDims.size() - 1; i >= 0; i--) {
int dim = srcDims[i];
if (subblock[dim] == 1) {
continue;
}
for (int j = i + 1; j < srcDims.size(); j++) {
if (shapeRemaining[j] != 1) {
return failure();
}
}
if (shapeRemaining[i] >= subblock[dim]) {
assert(shapeRemaining[i] % subblock[dim] == 0);
shapeRemaining[i] /= subblock[dim];
} else {
shapeRemaining[i] = 0;
}
if (shapeRemaining[i] == 0 && i != 0) {
return failure();
}
}
return success();
};
if (!succeeded(checkSubblock(src.getSizePerThread())) ||
!succeeded(checkSubblock(src.getThreadsPerWarp())) ||
!succeeded(checkSubblock(src.getWarpsPerCTA()))) {
return failure();
}
}
// Remaining destination extents, shared across the three
// computeSubblockSize calls below (sizePerThread, then threadsPerWarp, then
// warpsPerCTA are carved out of dstShape in that order).
SmallVector<int64_t> dstShapeRemaining(dstShape);
// For each group, spread the product of the source subblock sizes over the
// group's destination dims from last to first, capping each dim by its
// remaining destination extent. Any leftover factor is multiplied into the
// group's first destination dim. `fieldName` is unused, and the lambda
// always returns success.
auto computeSubblockSize = [&](ArrayRef<unsigned> srcSubblock,
SmallVector<unsigned> &dstSubblock,
StringRef fieldName) -> LogicalResult {
dstSubblock.resize(dstShape.size());
for (const auto &[srcDims, dstDims] : decomp) {
int64_t subblockRemaining = product(gather(srcSubblock, srcDims));
for (int i = dstDims.size() - 1; i >= 0; i--) {
auto &val = dstSubblock[dstDims[i]];
auto &shapeRemaining = dstShapeRemaining[dstDims[i]];
val = std::min(subblockRemaining, shapeRemaining);
assert(shapeRemaining % val == 0);
subblockRemaining /= val;
shapeRemaining /= val;
}
dstSubblock[dstDims[0]] *= subblockRemaining;
}
return success();
};
SmallVector<unsigned> dstSizePerThread;
SmallVector<unsigned> dstThreadsPerWarp;
SmallVector<unsigned> dstWarpsPerCTA;
if (!succeeded(computeSubblockSize(src.getSizePerThread(), dstSizePerThread,
"sizePerThread")) ||
!succeeded(computeSubblockSize(src.getThreadsPerWarp(),
dstThreadsPerWarp, "threadsPerWarp")) ||
!succeeded(computeSubblockSize(src.getWarpsPerCTA(), dstWarpsPerCTA,
"warpsPerCTA"))) {
return failure();
}
// Build the destination order. Groups are sorted by the source-order
// position of their first source dim. Within each group, destination dims
// receive increasing inverse-order ranks walking from last to first.
llvm::sort(decomp, [&](const auto &a, const auto &b) {
const auto &[srcDimsA, dstDimsA] = a;
const auto &[srcDimsB, dstDimsB] = b;
return srcInvOrder[srcDimsA.front()] < srcInvOrder[srcDimsB.front()];
});
SmallVector<unsigned> dstInvOrder(dstShape.size());
int i = 0;
for (const auto &[srcDims, dstDims] : decomp) {
for (auto dim : reverse(dstDims)) {
dstInvOrder[dim] = i++;
}
}
auto dstOrder = inversePermutation(dstInvOrder);
// Multi-CTA layouts were rejected above, so the CTA layout is all ones with
// an identity CTA order.
auto CTALayout = CTALayoutAttr::get(
src.getContext(),
SmallVector<unsigned>(dstShape.size(), 1),
SmallVector<unsigned>(dstShape.size(), 1),
llvm::to_vector(llvm::seq<unsigned>(dstShape.size())));
dstEnc = BlockedEncodingAttr::get(src.getContext(), dstSizePerThread,
dstThreadsPerWarp, dstWarpsPerCTA,
dstOrder, CTALayout);
return success();
}
LogicalResult
verifyLayoutsAreEqual(ArrayRef<int64_t> shape, Attribute expected,
                      Attribute got,
                      std::optional<Location> loc) const override {
  // Identical attributes (including both being null) trivially match.
  if (expected == got)
    return success();

  // Reports the mismatch through `loc` when the caller supplied one.
  auto mismatch = [&]() {
    return emitOptionalError(loc, "Expected result encoding ", expected,
                             " but was ", got);
  };

  // A missing encoding never matches a present one. This used to fail
  // silently, without a diagnostic.
  if (!expected || !got)
    return mismatch();

  // Structural comparison via linear layouts is only done for distributed
  // encodings. Any other pair that is not identical is a mismatch, whereas
  // cast<> used to assert on, for example, shared encodings.
  auto expectedDist = dyn_cast<DistributedEncodingTrait>(expected);
  auto gotDist = dyn_cast<DistributedEncodingTrait>(got);
  if (!expectedDist || !gotDist ||
      !areLayoutsEquivalent(shape, expectedDist, gotDist))
    return mismatch();
  return success();
}
LogicalResult
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
                       ArrayRef<int64_t> dstShape, Attribute &dstEnc,
                       std::optional<Location> loc) const override {
  // Infers the result encoding of a reshape from `srcShape` to `dstShape`.
  // The legacy BlockedEncodingAttr inference is tried first. If it does not
  // apply, distributed encodings fall back to a LinearEncodingAttr derived
  // from the reshaped linear layout.
  if (product(srcShape) != product(dstShape)) {
    return emitOptionalError(loc, "numel of dst shape does not match "
                                  "numel of src shape");
  }
  if (succeeded(
          inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc)))
    return success();

  if (!isa<DistributedEncodingTrait>(srcEnc)) {
    return emitOptionalError(loc,
                             "Failed MemDescReshapeOp encoding inference");
  }

  MLIRContext *ctx = srcEnc.getContext();
  // A tensor type is only built to query the layout. The element type
  // (unsigned 32-bit int) looks like a placeholder. NOTE(review): confirm it
  // cannot influence the linear layout of distributed encodings.
  Type placeholderElemTy = IntegerType::get(ctx, 32, IntegerType::Unsigned);
  auto srcTy = RankedTensorType::get(srcShape, placeholderElemTy, srcEnc);
  LinearLayout reshaped =
      inferReshapeLinearLayout(cast<TensorOrMemDesc>(srcTy), dstShape);
  dstEnc = LinearEncodingAttr::get(ctx, reshaped);
  return success();
}
LogicalResult
inferDefaultJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
                           ArrayRef<int64_t> shape,
                           std::optional<Location> loc) const override {
  // Infers the encoding of a join result, which appends a trailing dim of
  // size 2 to `shape`.
  auto ctx = getContext();
  if (auto enc = mlir::dyn_cast<SliceEncodingAttr>(srcEnc);
      enc && enc.getDim() == shape.size()) {
    // srcEnc slices away the would-be joined (last) dimension. If splitting
    // the slice's parent along that dimension reproduces srcEnc, the parent
    // is the joined encoding.
    SmallVector<int64_t> joinedShape(shape);
    joinedShape.push_back(2);
    auto parent = enc.getParent();
    Attribute splitEnc;
    // This split inference is only a probe, so pass no location. A failure
    // here must not emit a diagnostic, because we fall back below.
    auto result =
        inferSplitOpEncoding(parent, splitEnc, joinedShape, std::nullopt);
    if (succeeded(result) &&
        areLayoutsEquivalent(shape, cast<DistributedEncodingTrait>(splitEnc),
                             cast<DistributedEncodingTrait>(srcEnc))) {
      dstEnc = parent;
      return success();
    }
  } else if (auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc)) {
    // Blocked: append the new dim with 2 elements per thread and 1 thread,
    // warp and CTA, and place it at the front of both orders.
    auto append = [](ArrayRef<unsigned> vals, int val) {
      SmallVector<unsigned> ret(vals);
      ret.push_back(val);
      return ret;
    };
    auto appendMajorDim = [](ArrayRef<unsigned> order) {
      SmallVector<unsigned> ret(order);
      ret.insert(ret.begin(), ret.size());
      return ret;
    };
    dstEnc = BlockedEncodingAttr::get(
        enc.getContext(), append(enc.getSizePerThread(), 2),
        append(enc.getThreadsPerWarp(), 1), append(enc.getWarpsPerCTA(), 1),
        appendMajorDim(enc.getOrder()),
        CTALayoutAttr::get(enc.getContext(), append(enc.getCTAsPerCGA(), 1),
                           append(enc.getCTASplitNum(), 1),
                           appendMajorDim(enc.getCTAOrder())));
    return success();
  }

  // Generic path: append a size-1 output dim to the linear layout, then
  // join along it (forward inference).
  auto ll = toLinearLayout(shape, srcEnc);
  SmallVector<int64_t> dstShape(shape.begin(), shape.end());
  dstShape.push_back(1);
  ll = ll.reshapeOuts(standardOutDimPairs(ctx, dstShape));
  auto axis = dstShape.size() - 1;
  auto newLl = LinearLayout::empty();
  auto result = tryJoinOnAxis(ctx, ll, newLl, true, axis, loc);
  // Propagate a failure instead of asserting, so a release build never
  // builds an encoding from the still-empty layout.
  if (!result.succeeded())
    return result;
  dstEnc = LinearEncodingAttr::get(ctx, newLl);
  return success();
}
LogicalResult
inferSplitOpEncoding(Attribute srcEnc, Attribute &dstEnc,
                     ArrayRef<int64_t> shape,
                     std::optional<Location> loc) const override {
  // Infers the encoding of each split result. A split removes the trailing
  // dim of `shape`, which must have size 2.
  //
  // Validate the shape up front so both paths below can rely on it. An empty
  // shape is rejected here instead of being indexed at size() - 1.
  if (shape.empty() || shape.back() != 2) {
    return emitOptionalError(
        loc, "SplitOp input shape should have 2 in the last dim");
  }

  // Fast path: a blocked encoding where each thread owns both halves of the
  // split dim (sizePerThread 2, and 1 thread / warp / CTA along it) simply
  // drops that dim.
  auto enc = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
  bool isSimpleSplit = (enc && (enc.getSizePerThread().back() == 2) &&
                        (enc.getThreadsPerWarp().back() == 1) &&
                        (enc.getWarpsPerCTA().back() == 1) &&
                        (enc.getCTAsPerCGA().back() == 1));
  if (isSimpleSplit) {
    unsigned splitDim = enc.getOrder().size() - 1;
    // Removes splitDim (the largest index) from a permutation of
    // [0, rank), leaving a valid permutation of [0, rank - 1). This is used
    // for both order and CTAOrder. CTAOrder was previously truncated with
    // drop_front(1), which is only correct when splitDim comes first.
    auto dropSplitDim = [splitDim](ArrayRef<unsigned> perm) {
      SmallVector<unsigned> ret(perm);
      ret.erase(std::remove(ret.begin(), ret.end(), splitDim), ret.end());
      return ret;
    };
    dstEnc = BlockedEncodingAttr::get(
        enc.getContext(), ArrayRef(enc.getSizePerThread()).drop_back(1),
        ArrayRef(enc.getThreadsPerWarp()).drop_back(1),
        ArrayRef(enc.getWarpsPerCTA()).drop_back(1),
        dropSplitDim(enc.getOrder()),
        CTALayoutAttr::get(enc.getContext(),
                           ArrayRef(enc.getCTAsPerCGA()).drop_back(1),
                           ArrayRef(enc.getCTASplitNum()).drop_back(1),
                           dropSplitDim(enc.getCTAOrder())));
    return success();
  }

  // Generic path: split along the last axis of the linear layout (backward
  // join), then drop that axis from the output dims.
  auto axis = shape.size() - 1;
  auto ctx = getContext();
  auto ll = toLinearLayout(shape, srcEnc);
  auto newLl = LinearLayout::empty();
  auto result = tryJoinOnAxis(ctx, ll, newLl, false, axis, loc);
  if (!result.succeeded()) {
    return failure();
  }
  SmallVector<int64_t> dstShape(shape.begin(), shape.end());
  dstShape.pop_back();
  newLl = newLl.reshapeOuts(standardOutDimPairs(ctx, dstShape));
  dstEnc = LinearEncodingAttr::get(ctx, newLl);
  return success();
}
LogicalResult
inferFp4ToFpOpEncoding(ArrayRef<int64_t> shape, int axis, Attribute inEnc,
                       Attribute &outEnc, bool fwdInference,
                       std::optional<Location> loc) const override {
  // Infers the encoding on the other side of an Fp4ToFp op. Forward
  // inference doubles the per-thread element count along `axis`, and
  // backward inference halves it.
  auto *context = getContext();

  // Rescales a per-thread element count. Yields std::nullopt when backward
  // inference cannot halve it (count <= 1).
  auto rescale = [fwdInference](unsigned count) -> std::optional<unsigned> {
    if (fwdInference)
      return count * 2;
    if (count > 1)
      return count / 2;
    return std::nullopt;
  };

  // Dot-operand and blocked encodings get a dedicated update, but only when
  // the encoding's order has 0 at position `axis`.
  if (getOrder(cast<DistributedEncodingTrait>(inEnc), shape)[axis] == 0) {
    if (auto dotEnc = mlir::dyn_cast<DotOperandEncodingAttr>(inEnc)) {
      std::optional<unsigned> kWidth = rescale(dotEnc.getKWidth());
      if (!kWidth) {
        return emitOptionalError(loc, "Fp4ToFpOp requires at least 2 elements "
                                      "per thread in the axis dimension");
      }
      outEnc = DotOperandEncodingAttr::get(context, dotEnc.getOpIdx(),
                                           dotEnc.getParent(), *kWidth);
      return success();
    }
    if (auto blockedEnc = mlir::dyn_cast<BlockedEncodingAttr>(inEnc)) {
      auto sizePerThread = llvm::to_vector(blockedEnc.getSizePerThread());
      std::optional<unsigned> scaled = rescale(sizePerThread[axis]);
      if (!scaled) {
        return emitOptionalError(loc, "Fp4ToFpOp requires at least 2 elements "
                                      "per thread in the axis dimension");
      }
      sizePerThread[axis] = *scaled;
      outEnc = BlockedEncodingAttr::get(
          context, sizePerThread, blockedEnc.getThreadsPerWarp(),
          blockedEnc.getWarpsPerCTA(), blockedEnc.getOrder(),
          blockedEnc.getCTALayout());
      return success();
    }
  }

  // Fallback: join (forward) or split (backward) the linear layout along
  // `axis`.
  auto layout = toLinearLayout(shape, inEnc);
  auto joined = LinearLayout::empty();
  auto result = tryJoinOnAxis(context, layout, joined, fwdInference, axis, loc);
  if (!result.succeeded())
    return result;
  outEnc = LinearEncodingAttr::get(context, joined);
  return success();
}
};
struct TritonGPUVerifyTensorLayoutInterface
    : public triton::DialectVerifyTensorLayoutInterface {
  using DialectVerifyTensorLayoutInterface::DialectVerifyTensorLayoutInterface;

  // Verifies that `layout`, attached to `rankedTy` on `op`, meets these
  // requirements:
  // - it is a distributed encoding whose rank matches the tensor;
  // - every tensor dim is a power of two;
  // - its lane and block counts match the enclosing module;
  // - its warp count matches the num-warps context.
  LogicalResult verifyTensorLayout(
      Attribute layout, RankedTensorType rankedTy, Operation *op,
      function_ref<InFlightDiagnostic()> makeErr) const override {
    auto distr = dyn_cast<triton::gpu::DistributedEncodingTrait>(layout);
    if (!distr)
      return makeErr()
             << "Non-distributed layout is not allowed in tensor type.";
    // Signed, to compare cleanly with RankedTensorType::getRank().
    int64_t rank = distr.getRepOrder().size();
    if (rank != rankedTy.getRank())
      return makeErr() << "Layout has rank " << rank
                       << ", but the tensor it's attached to has rank "
                       << rankedTy.getRank() << ".";
    if (llvm::any_of(rankedTy.getShape(),
                     [](int64_t i) { return !llvm::isPowerOf2_64(i); })) {
      return makeErr() << "Layout has shape " << rankedTy.getShape()
                       << ", but the tensor it's attached to has shape "
                       << rankedTy.getShape()
                       << " which is not a power of two.";
    }

    // The thread and CTA counts come from the enclosing module. An op that is
    // not nested in a module previously dereferenced a null ModuleOp here.
    ModuleOp module = op->getParentOfType<ModuleOp>();
    if (!module)
      return makeErr() << "Operation is not nested in a module; cannot "
                          "verify its tensor layout.";

    MLIRContext *ctx = op->getContext();
    auto ll = toLinearLayout(rankedTy);

    auto kLane = StringAttr::get(ctx, "lane");
    int moduleThreadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
    if (ll.getInDimSize(kLane) != moduleThreadsPerWarp) {
      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kLane)
                       << " threads per warp, but the module specifies "
                       << moduleThreadsPerWarp << " threads per warp.";
    }

    std::optional<int> moduleWarpsPerCTA = maybeLookupNumWarps(op);
    if (!moduleWarpsPerCTA) {
      return makeErr()
             << "Could not determine the number of warps per CTA. Operation "
                "is not in a context with `ttg.num-warps`.";
    }
    auto kWarp = StringAttr::get(ctx, "warp");
    if (ll.getInDimSize(kWarp) != *moduleWarpsPerCTA) {
      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kWarp)
                       << " warps per CTA, but the context requires "
                       << *moduleWarpsPerCTA << " warps per CTA.";
    }

    auto kBlock = StringAttr::get(ctx, "block");
    int moduleCTAsPerCGA = TritonGPUDialect::getNumCTAs(module);
    if (ll.getInDimSize(kBlock) != moduleCTAsPerCGA) {
      return makeErr() << layout << ".\nLayout has " << ll.getInDimSize(kBlock)
                       << " CTAs per CGA, but the context requires "
                       << moduleCTAsPerCGA << " CTAs per CGA.";
    }
    return success();
  }
};
// Converts the row-major linear index `idx` into per-dimension coordinates
// for `shape`, with the last dimension varying fastest. An empty shape yields
// an empty coordinate vector.
static SmallVector<int64_t> delinearizeIndex(int64_t idx,
                                             ArrayRef<int64_t> shape) {
  SmallVector<int64_t> coords(shape.size());
  int64_t remaining = idx;
  for (size_t d = shape.size(); d-- > 0;) {
    coords[d] = remaining % shape[d];
    remaining /= shape[d];
  }
  return coords;
}
// Returns how many characters `value` must be left-padded with so that its
// decimal form is as wide as that of `max`. Never negative. Previously the
// unsigned size difference wrapped and was narrowed to int, producing a
// negative "padding" whenever `value` was wider than `max`.
static int numCharacterPadding(int value, int max) {
  const int maxWidth = static_cast<int>(std::to_string(max).size());
  const int valueWidth = static_cast<int>(std::to_string(value).size());
  return maxWidth > valueWidth ? maxWidth - valueWidth : 0;
}
// Returns `value` in decimal, right-aligned with spaces to the width of `max`.
// A value already at least as wide as `max` is returned unpadded.
static std::string paddedString(int value, int max) {
  const std::size_t padding =
      static_cast<std::size_t>(numCharacterPadding(value, max));
  return std::string(padding, ' ') + std::to_string(value);
}
// Renders a shared-memory layout of `type` as text. Returns "" for a null
// type.
// - useHWPointOfView == false: entries are printed in the tensor's shape as
//   nested brackets.
// - useHWPointOfView == true: prints "Offset: o -> (coords)" lines grouped
//   under "Block: b:".
std::string getSharedLayoutStr(RankedTensorType type, bool useHWPointOfView) {
if (!type)
return "";
auto shape = type.getShape();
auto layout = type.getEncoding();
LinearLayout ll = triton::gpu::toLinearLayout(shape, layout);
StringAttr kOffset = StringAttr::get(type.getContext(), "offset");
StringAttr kBlock = StringAttr::get(type.getContext(), "block");
int64_t tensorSize = product(type.getShape());
auto enc = type.getEncoding();
unsigned numBlocks = getNumCTAs(enc);
int32_t blockSize = tensorSize / numBlocks;
// elementMapping feeds the non-HW view, offsetMapping the HW view.
std::vector<std::string> elementMapping(tensorSize);
std::vector<std::string> offsetMapping;
// Enumerate every (block, offset) input and apply the layout to get tensor
// coordinates. elementMapping is indexed by enumeration position (idx), not
// by the coordinates' linearized index, so the non-HW view shows, at each
// position, the coordinates produced for that offset.
int32_t idx = 0;
for (int32_t block = 0; block < numBlocks; block++) {
for (int32_t offset = 0; offset < blockSize; offset++) {
SmallVector<std::pair<StringAttr, int32_t>> inputs = {
{kBlock, block},
{kOffset, offset},
};
SmallVector<std::pair<StringAttr, int32_t>> outputs = ll.apply(inputs);
// sharedInfo is "(c0,c1,...)"; value uses ':' separators: "(c0:c1:...)".
std::string sharedInfo = "(";
std::string &value = elementMapping[idx];
if (!value.empty())
value += "|";
value += "(";
for (int i = 0; i < outputs.size(); i++) {
if (i > 0) {
sharedInfo += ",";
value += ":";
}
auto index = paddedString(outputs[i].second, shape[i]);
sharedInfo += index;
value += index;
}
value += ")";
sharedInfo += ")";
offsetMapping.push_back(sharedInfo);
idx++;
}
}
std::string layoutStr;
if (!useHWPointOfView) {
// Print elementMapping in the tensor's shape. An opening bracket is emitted
// for each trailing dim whose coordinate is 0, and a closing bracket for
// each trailing dim whose next coordinate wraps to 0. A newline follows the
// end of each innermost row.
int rank = type.getRank();
bool newLine = true;
for (int i = 0; i < tensorSize; i++) {
auto indices = delinearizeIndex(i, shape);
int numOpenBracket = 0;
for (int j = rank - 1; j >= 0; j--) {
if (indices[j] % shape[j] != 0)
break;
layoutStr += "[";
numOpenBracket++;
}
if (newLine) {
for (int j = 0; j < rank - numOpenBracket; j++)
layoutStr += " ";
newLine = false;
}
layoutStr += elementMapping[i];
auto nextIndices = delinearizeIndex(i + 1, shape);
for (int j = rank - 1; j >= 0; j--) {
if (nextIndices[j] % shape[j] != 0)
break;
layoutStr += "]";
}
if (nextIndices.back() % shape.back() == 0) {
layoutStr += "\n";
newLine = true;
} else {
layoutStr += ",";
}
}
} else {
// HW view: one line per offset, in the same (block, offset) order in which
// offsetMapping was filled.
uint32_t idx = 0;
for (int32_t block = 0; block < numBlocks; block++) {
layoutStr += "Block: " + std::to_string(block) + ":\n";
for (int32_t offset = 0; offset < (tensorSize / numBlocks); offset++) {
layoutStr += "Offset: " + std::to_string(offset) + " -> ";
layoutStr += offsetMapping[idx];
layoutStr += "\n";
idx++;
}
}
}
return layoutStr;
}
// Renders a distributed layout of `tensorType` as text. Returns "" when the
// tensor has no encoding.
// - useHWPointOfView == false: for each tensor element, in the tensor's shape,
//   lists the "[B<block>:]T<thread>:<register>" slots that map to it.
// - useHWPointOfView == true: for each block and warp, prints one row per
//   register index listing each lane's tensor coordinates.
std::string getDistributedLayoutStr(RankedTensorType tensorType,
bool useHWPointOfView) {
auto layout = tensorType.getEncoding();
if (!layout)
return "";
StringAttr kRegister = StringAttr::get(tensorType.getContext(), "register");
StringAttr kLane = StringAttr::get(tensorType.getContext(), "lane");
StringAttr kWarp = StringAttr::get(tensorType.getContext(), "warp");
StringAttr kBlock = StringAttr::get(tensorType.getContext(), "block");
LinearLayout ll = toLinearLayout(tensorType);
int64_t tensorSize = product(tensorType.getShape());
std::vector<std::string> elementMapping(tensorSize);
std::vector<std::string> threadMapping;
unsigned threadsPerWarp = ll.getInDimSize(kLane);
unsigned numWarpsPerCTA = ll.getInDimSize(kWarp);
unsigned numBlocks = ll.getInDimSize(kBlock);
int numElementsPerThreads = ll.getInDimSize(kRegister);
// Enumerate every (block, warp, lane, register) input in that nesting order
// and apply the layout to get tensor coordinates.
for (int blockId = 0; blockId < numBlocks; ++blockId) {
for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) {
for (int tid = 0; tid < threadsPerWarp; ++tid) {
for (int idx = 0; idx < numElementsPerThreads; ++idx) {
SmallVector<std::pair<StringAttr, int32_t>> inputs = {
{kBlock, blockId},
{kWarp, warpId},
{kLane, tid},
{kRegister, idx}};
SmallVector<std::pair<StringAttr, int32_t>> outputs =
ll.apply(inputs);
// Row-major linearization of the output coordinates.
int32_t linearizedIdx = 0;
int stride = 1;
for (int i = outputs.size() - 1; i >= 0; i--) {
linearizedIdx += outputs[i].second * stride;
stride *= tensorType.getDimSize(i);
}
// Several hardware slots can map to the same element; their labels
// are joined with '|'.
std::string &value = elementMapping[linearizedIdx];
if (!value.empty())
value += "|";
// Left-pad so that labels line up across elements. The thread id is
// global within the CTA: tid + warpId * threadsPerWarp.
int padding = numCharacterPadding(blockId, numBlocks) +
numCharacterPadding(tid + warpId * threadsPerWarp,
numWarpsPerCTA * threadsPerWarp) +
numCharacterPadding(idx, numElementsPerThreads);
for (int i = 0; i < padding; i++)
value += " ";
if (numBlocks > 1)
value += "B" + std::to_string(blockId) + ":";
value += "T" + std::to_string(tid + warpId * threadsPerWarp) + ":" +
std::to_string(idx);
// HW view entry: "(c0,c1,...)", appended in enumeration order.
std::string threadInfo = "(";
for (int i = 0; i < outputs.size(); i++) {
if (i > 0)
threadInfo += ",";
threadInfo +=
paddedString(outputs[i].second, tensorType.getDimSize(i));
}
threadInfo += ")";
threadMapping.push_back(threadInfo);
}
}
}
}
std::string layoutStr;
if (!useHWPointOfView) {
// Print elementMapping in the tensor's shape. An opening bracket is emitted
// for each trailing dim whose coordinate is 0, and a closing bracket for
// each trailing dim whose next coordinate wraps to 0. A newline follows the
// end of each innermost row.
int rank = tensorType.getRank();
bool newLine = true;
for (int i = 0; i < tensorSize; i++) {
auto indices = delinearizeIndex(i, tensorType.getShape());
int numOpenBracket = 0;
for (int j = rank - 1; j >= 0; j--) {
if (indices[j] % tensorType.getDimSize(j) != 0)
break;
layoutStr += "[";
numOpenBracket++;
}
if (newLine) {
for (int j = 0; j < rank - numOpenBracket; j++)
layoutStr += " ";
newLine = false;
}
layoutStr += elementMapping[i];
auto nextIndices = delinearizeIndex(i + 1, tensorType.getShape());
for (int j = rank - 1; j >= 0; j--) {
if (nextIndices[j] % tensorType.getDimSize(j) != 0)
break;
layoutStr += "]";
}
if (nextIndices.back() % tensorType.getShape().back() == 0) {
layoutStr += "\n";
newLine = true;
} else {
layoutStr += ", ";
}
}
} else {
// HW view: for each block and warp, one row per register index listing
// every lane's coordinates. The index into threadMapping mirrors the
// (block, warp, lane, register) fill order above.
for (int blockId = 0; blockId < numBlocks; blockId++) {
if (numBlocks > 1)
layoutStr += "Block" + std::to_string(blockId) + ":\n";
for (int warpId = 0; warpId < numWarpsPerCTA; warpId++) {
layoutStr += "Warp" + std::to_string(warpId) + ":\n";
for (int idx = 0; idx < numElementsPerThreads; ++idx) {
for (int tid = 0; tid < threadsPerWarp; ++tid) {
int linearizedIdx =
blockId * numWarpsPerCTA * threadsPerWarp *
numElementsPerThreads +
warpId * threadsPerWarp * numElementsPerThreads +
tid * numElementsPerThreads + idx;
layoutStr += threadMapping[linearizedIdx];
if (tid < threadsPerWarp - 1)
layoutStr += ", ";
}
layoutStr += "\n";
}
}
}
}
return layoutStr;
}
// Returns `s` with a leading unit batch dimension when it has rank 2; a
// rank-3 shape is returned unchanged. Only ranks 2 and 3 are accepted.
template <typename T>
llvm::SmallVector<T>
mlir::triton::gpu::expandMatrixShapeWithBatch(llvm::ArrayRef<T> s) {
  const auto rank = s.size();
  assert(rank == 2 || rank == 3);
  llvm::SmallVector<T> expanded(s.begin(), s.end());
  if (rank == 2)
    expanded.insert(expanded.begin(), T(1));
  return expanded;
}
// Explicit instantiations for the element types used by callers.
template llvm::SmallVector<int64_t>
mlir::triton::gpu::expandMatrixShapeWithBatch<int64_t>(
    llvm::ArrayRef<int64_t> s);
template llvm::SmallVector<unsigned>
mlir::triton::gpu::expandMatrixShapeWithBatch<unsigned>(
    llvm::ArrayRef<unsigned> s);
// Extends a rank-2 order to rank 3. Each existing dim index is shifted up by
// one to make room for the batch dim, and the batch dim (index 0) is
// appended as the last entry. A rank-3 order is returned unchanged.
llvm::SmallVector<unsigned>
mlir::triton::gpu::expandMatrixOrderWithBatch(llvm::ArrayRef<unsigned> o) {
  int rank = o.size();
  assert(rank == 2 || rank == 3);
  if (rank == 3)
    return llvm::SmallVector<unsigned>(o);
  llvm::SmallVector<unsigned> expanded;
  expanded.reserve(3);
  for (unsigned dim : o)
    expanded.push_back(dim + 1);
  expanded.push_back(0);
  return expanded;
}
// Dispatches to the shared or distributed layout printer based on the
// tensor's encoding kind. Returns "" for a null tensor type or a tensor
// without an encoding. Any other encoding kind is a fatal error.
std::string mlir::triton::gpu::getLayoutStr(RankedTensorType tensorType,
                                            bool useHWPointOfView) {
  // getDistributedLayoutStr already returns "" for an encoding-less tensor,
  // and getSharedLayoutStr does so for a null type. Previously isa<> asserted
  // on a null encoding before either helper was reached.
  if (!tensorType)
    return "";
  auto layout = tensorType.getEncoding();
  if (!layout)
    return "";
  if (mlir::isa<SharedEncodingTrait>(layout))
    return getSharedLayoutStr(tensorType, useHWPointOfView);
  if (mlir::isa<DistributedEncodingTrait>(layout))
    return getDistributedLayoutStr(tensorType, useHWPointOfView);
  // report_fatal_error does not return, so no trailing return is needed.
  llvm::report_fatal_error("Unimplemented usage of getLayoutStr");
}
// Prints getLayoutStr's non-hardware view of `tensorType` to stderr.
void mlir::triton::gpu::dumpLayout(RankedTensorType tensorType) {
  const std::string text =
      getLayoutStr(tensorType, /*useHWPointOfView=*/false);
  llvm::errs() << text;
}
// Prints getLayoutStr's hardware view of `tensorType` to stderr.
void mlir::triton::gpu::dumpHWLayout(RankedTensorType tensorType) {
  const std::string text =
      getLayoutStr(tensorType, /*useHWPointOfView=*/true);
  llvm::errs() << text;
}
namespace {
// External model that lets RankedTensorType be used through the
// TensorOrMemDesc type interface. Each query forwards to the tensor type.
struct TensorModel
    : public triton::gpu::TensorOrMemDesc::ExternalModel<TensorModel,
                                                         RankedTensorType> {
  static RankedTensorType asTensor(Type type) {
    return cast<RankedTensorType>(type);
  }
  Type getElementType(Type type) const {
    return asTensor(type).getElementType();
  }
  Attribute getEncoding(Type type) const { return asTensor(type).getEncoding(); }
  ArrayRef<int64_t> getShape(Type type) const {
    return asTensor(type).getShape();
  }
  int64_t getRank(Type type) const { return asTensor(type).getRank(); }
  int64_t getElementTypeBitWidth(Type type) const {
    return asTensor(type).getElementTypeBitWidth();
  }
};

// External model that lets MemDescType be used through the TensorOrMemDesc
// type interface. The rank is the shape's length, and the element bit width
// is read with getIntOrFloatBitWidth().
struct MemDescModel
    : public triton::gpu::TensorOrMemDesc::ExternalModel<MemDescModel,
                                                         MemDescType> {
  static MemDescType asMemDesc(Type type) { return cast<MemDescType>(type); }
  Type getElementType(Type type) const {
    return asMemDesc(type).getElementType();
  }
  Attribute getEncoding(Type type) const {
    return asMemDesc(type).getEncoding();
  }
  ArrayRef<int64_t> getShape(Type type) const {
    return asMemDesc(type).getShape();
  }
  int64_t getRank(Type type) const {
    return asMemDesc(type).getShape().size();
  }
  int64_t getElementTypeBitWidth(Type type) const {
    return asMemDesc(type).getElementType().getIntOrFloatBitWidth();
  }
};
} // namespace
// Registers the TritonGPU types, attributes, operations and dialect
// interfaces, and attaches the TensorOrMemDesc external models to the tensor
// and memdesc types.
void TritonGPUDialect::initialize() {
  registerTypes();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "triton/Dialect/TritonGPU/IR/AttrDefs.cpp.inc"
      >();
  addOperations<
#define GET_OP_LIST
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
#include "triton/Dialect/TritonGPU/IR/OpsEnums.cpp.inc"
      >();
  addInterfaces<TritonInlinerInterface, TritonGPUOpAsmInterface,
                TritonGPUInferLayoutInterface,
                TritonGPUVerifyTensorLayoutInterface>();
  MLIRContext &context = *getContext();
  RankedTensorType::attachInterface<TensorModel>(context);
  MemDescType::attachInterface<MemDescModel>(context);
}
// Rejects dialect attributes placed on the wrong kind of op:
// - num-CTAs, target and threads-per-warp are only allowed on modules;
// - num-warps is only allowed on modules and tt.func ops.
LogicalResult TritonGPUDialect::verifyOperationAttribute(Operation *op,
                                                         NamedAttribute attr) {
  StringAttr name = attr.getName();
  const bool isModuleOnlyAttr = llvm::is_contained(
      {AttrNumCTAsName, AttrTargetName, AttrNumThreadsPerWarp}, name);
  if (isModuleOnlyAttr && !isa<ModuleOp>(op)) {
    return op->emitOpError("has unexpected attribute ")
           << name << " which is expected only on `module` ops";
  }
  const bool isNumWarpsAttr = name == AttrNumWarpsName;
  if (isNumWarpsAttr && !isa<ModuleOp, FuncOp>(op)) {
    return op->emitOpError("has unexpected attribute ")
           << name << " which is expected only on `module` or `tt.func` ops";
  }
  return success();
}
// Number of CTAs from the module's num-CTAs attribute; 1 when it is absent.
int TritonGPUDialect::getNumCTAs(ModuleOp module) {
  auto numCTAsAttr = module->getAttrOfType<IntegerAttr>(AttrNumCTAsName);
  return numCTAsAttr ? numCTAsAttr.getInt() : 1;
}
// Threads per warp from the module's attribute; 32 when it is absent.
int TritonGPUDialect::getThreadsPerWarp(ModuleOp module) {
  auto threadsAttr = module->getAttrOfType<IntegerAttr>(AttrNumThreadsPerWarp);
  return threadsAttr ? threadsAttr.getInt() : 32;
}
// Finds the number of warps that applies to `op`, walking up the parent
// chain:
// - module and tt.func ops contribute their num-warps attribute, if present;
// - an op directly inside a WarpSpecializePartitionsOp region uses that
//   partition's warp count.
// Returns std::nullopt when no ancestor specifies it.
std::optional<int> triton::gpu::maybeLookupNumWarps(Operation *op) {
  if (isa<ModuleOp, FuncOp>(op)) {
    if (auto attr = op->getAttrOfType<IntegerAttr>(AttrNumWarpsName))
      return attr.getInt();
  } else if (auto partitions = dyn_cast_or_null<WarpSpecializePartitionsOp>(
                 op->getParentOp())) {
    // dyn_cast_or_null: a detached op has no parent, and plain dyn_cast
    // asserts on a null pointer.
    unsigned idx = op->getParentRegion()->getRegionNumber();
    return partitions.getParentOp().getPartitionNumWarps()[idx];
  }
  if (Operation *parent = op->getParentOp())
    return maybeLookupNumWarps(parent);
  return std::nullopt;
}
// Like maybeLookupNumWarps, but treats a missing warp count as a fatal error
// (after emitting an op error on `op`).
int triton::gpu::lookupNumWarps(Operation *op) {
  if (std::optional<int> numWarps = maybeLookupNumWarps(op))
    return *numWarps;
  op->emitOpError(
      "is not contained within a context that specifies the number of warps");
  llvm::report_fatal_error("failed to lookup the number of warps, the "
                           "surrounding module should contain a " +
                           Twine(AttrNumWarpsName) + " attribute");
}
int triton::gpu::lookupThreadsPerWarp(OpBuilder &rewriter) {
assert(rewriter.getInsertionBlock() && "expected an insertion point");
Operation *op =
rewriter.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
assert(op && "cannot check threads per warp outside of module");
return triton::gpu::TritonGPUDialect::getThreadsPerWarp(cast<ModuleOp>(op));
}
int triton::gpu::lookupNumCTAs(OpBuilder &rewriter) {
assert(rewriter.getInsertionBlock() && "expected an insertion point");
Operation *op =
rewriter.getInsertionBlock()->getParentOp()->getParentOfType<ModuleOp>();
assert(op && "cannot check number of CTAs outside of module");
return triton::gpu::TritonGPUDialect::getNumCTAs(cast<ModuleOp>(op));
}
// Two distributed encodings are equivalent for `shape` when they produce the
// same linear layout for that shape.
bool triton::gpu::areLayoutsEquivalent(ArrayRef<int64_t> shape,
                                       DistributedEncodingTrait lhs,
                                       DistributedEncodingTrait rhs) {
  const LinearLayout lhsLayout = triton::gpu::toLinearLayout(shape, lhs);
  const LinearLayout rhsLayout = triton::gpu::toLinearLayout(shape, rhs);
  return lhsLayout == rhsLayout;
}
// Returns true when at least `numElems` consecutive inputs of the memdesc's
// linear layout map to consecutive outputs. The outputs are the tensor dims,
// reversed and then flattened.
bool triton::gpu::isInnermostContiguous(MemDescType type, unsigned numElems) {
  // Previously this also computed several unused locals. One of them
  // dereferenced the first input dim name (UB on a layout with no input
  // dims), and another called getContext() on a possibly null encoding.
  LinearLayout layout = toLinearLayout(type);
  // Reverse the output dims so the last tensor dim comes first before
  // flattening. NOTE(review): this assumes flattenOuts treats the first out
  // dim as the most minor — confirm.
  auto outNames = layout.getOutDimNames();
  SmallVector<StringAttr> revOut(outNames.begin(), outNames.end());
  std::reverse(revOut.begin(), revOut.end());
  layout = layout.transposeOuts(revOut).flattenOuts();
  return layout.getNumConsecutiveInOut() >= numElems;
}
// Reshapes the linear layout of `srcTy` to `dstShape`. Both shapes must hold
// the same number of elements (asserted).
LinearLayout triton::gpu::inferReshapeLinearLayout(TensorOrMemDesc srcTy,
                                                   ArrayRef<int64_t> dstShape) {
  auto srcLayout = toLinearLayout(srcTy);
  assert(product(srcTy.getShape()) == product(dstShape));
  return reshapeLayout(srcTy.getContext(), srcLayout, dstShape);
}