#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/Support/DebugStringHelper.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Attributes.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonGPU/IR/Types.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Tools/LayoutUtils.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/LogicalResult.h"
/// Parses a comma-separated list of 32-bit integers and stores it in `attr`
/// as a DenseI32ArrayAttr.
///
/// Returns failure if the list itself or any of its elements cannot be parsed;
/// `attr` is only assigned on success.
static mlir::ParseResult parseOffsets(mlir::OpAsmParser &p,
                                      mlir::DenseI32ArrayAttr &attr) {
  llvm::SmallVector<int32_t> parsed;
  auto parseElement = [&]() -> mlir::ParseResult {
    int32_t element;
    if (p.parseInteger(element))
      return mlir::failure();
    parsed.push_back(element);
    return mlir::success();
  };
  if (p.parseCommaSeparatedList(parseElement))
    return mlir::failure();
  attr = p.getBuilder().getDenseI32ArrayAttr(parsed);
  return mlir::success();
}
/// Prints the elements of `attr` separated by commas. The `op` parameter is
/// not used by this function.
static void printOffsets(mlir::OpAsmPrinter &p, mlir::Operation *op,
                         mlir::DenseI32ArrayAttr attr) {
  llvm::ArrayRef<int32_t> elements = attr.asArrayRef();
  llvm::interleaveComma(elements, p, [&p](int32_t element) { p << element; });
}
#define GET_OP_CLASSES
#include "triton/Dialect/TritonGPU/IR/Ops.cpp.inc"
namespace mlir::triton::gpu {
namespace {
/// Returns true when `value` has a TensorOrMemDesc type whose encoding is
/// non-null and is an instance of `T`; returns false in every other case.
template <typename T> bool hasEncoding(Value value) {
  auto tensorType = dyn_cast<TensorOrMemDesc>(value.getType());
  if (!tensorType)
    return false;
  auto encoding = tensorType.getEncoding();
  return encoding && isa<T>(encoding);
}
/// Convenience wrapper around hasEncoding for DotOperandEncodingAttr.
bool hasDotOperandEncoding(Value value) {
  return hasEncoding<DotOperandEncodingAttr>(value);
}
/// Returns true when the source and result encodings of `op` are reported as
/// equal by the layout-inference interface of the source encoding's dialect
/// (`verifyLayoutsAreEqual` on the source shape).
bool isConvertTrivial(ConvertLayoutOp op) {
  auto inputType = op.getSrc().getType();
  auto inputEncoding = inputType.getEncoding();
  auto outputEncoding = op.getType().getEncoding();
  auto *inferLayout =
      cast<DialectInferLayoutInterface>(&inputEncoding.getDialect());
  return inferLayout
      ->verifyLayoutsAreEqual(inputType.getShape(), inputEncoding,
                              outputEncoding, {})
      .succeeded();
}
}
/// reshape(cvt(x)) -> reshape(x).
///
/// The convert feeding a reshape is dropped when either
///  * the convert is trivial (see isConvertTrivial), or
///  * the reshape from the convert's source type is not an expensive view and
///    the reshape allows reordering.
struct CanonicalizeConvertFromReshape
    : public mlir::OpRewritePattern<triton::ReshapeOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(triton::ReshapeOp op,
                  PatternRewriter &rewriter) const override {
    auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
    if (!convert)
      return failure();

    // A non-trivial convert may only be skipped under the extra conditions
    // below; a trivial one can always be skipped.
    if (!isConvertTrivial(convert)) {
      if (isExpensiveView(convert.getSrc().getType(), op.getType()))
        return failure();
      if (!op.getAllowReorder())
        return failure();
    }

    rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
        op, op.getType(), convert.getSrc(), op.getAllowReorder(),
        op.getEfficientLayout());
    return success();
  }
};
/// Canonicalizations rooted at tt.trans:
///  * a transpose with an identity order is replaced by a convert_layout from
///    its source to its result type;
///  * trans(cvt(x)) -> trans(x) when the convert is trivial.
struct CanonicalizeConvertFromTranspose
    : public mlir::OpRewritePattern<triton::TransOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(triton::TransOp op,
                  PatternRewriter &rewriter) const override {
    // Identity permutation: express the op as a convert_layout instead.
    if (isIota(op.getOrder())) {
      rewriter.replaceOpWithNewOp<ConvertLayoutOp>(op, op.getType(),
                                                   op.getSrc());
      return success();
    }

    auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
    if (!convert)
      return failure();
    if (!isConvertTrivial(convert))
      return failure();

    rewriter.replaceOpWithNewOp<triton::TransOp>(
        op, op.getType(), convert.getSrc(), op.getOrder());
    return success();
  }
};
/// histogram(cvt(x), mask) -> histogram(x, cvt(mask)).
///
/// The convert on the histogram source is removed. When a mask is present, a
/// new convert_layout is created for it whose result type is
/// `getI1SameShape` of the new source type.
struct CanonicalizeConvertFromHistogram
    : public mlir::OpRewritePattern<triton::HistogramOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(triton::HistogramOp op,
                  PatternRewriter &rewriter) const override {
    auto input = op.getSrc();
    auto convert = input.getDefiningOp<ConvertLayoutOp>();
    if (!convert)
      return failure();
    input = convert.getSrc();

    auto mask = op.getMask();
    if (mask) {
      auto maskType = getI1SameShape(input.getType());
      rewriter.setInsertionPoint(op);
      mask = rewriter.create<ConvertLayoutOp>(op.getLoc(), maskType, mask);
    }

    rewriter.replaceOpWithNewOp<triton::HistogramOp>(
        op, op->getResult(0).getType(), input, mask);
    return success();
  }
};
/// gather(cvt(src), idx) -> gather(src, idx).
///
/// Not applied when the gather is flagged with an efficient layout.
struct CanonicalizeConvertFromGatherSource : public OpRewritePattern<GatherOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(GatherOp op, PatternRewriter &rewriter) const override {
    if (op.getEfficientLayout())
      return failure();

    if (auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>()) {
      rewriter.replaceOpWithNewOp<GatherOp>(op, convert.getSrc(),
                                            op.getIndices(), op.getAxis());
      return success();
    }
    return failure();
  }
};
/// local_alloc(cvt(x)) -> local_alloc(x).
///
/// Only applies to allocations that have a source operand.
struct CanonicalizeConvertFromAlloc
    : public mlir::OpRewritePattern<triton::gpu::LocalAllocOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(triton::gpu::LocalAllocOp op,
                  PatternRewriter &rewriter) const override {
    auto initValue = op.getSrc();
    if (!initValue)
      return failure();

    auto convert = initValue.getDefiningOp<ConvertLayoutOp>();
    if (!convert)
      return failure();

    rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
        op, op->getResult(0).getType(), convert.getSrc());
    return success();
  }
};
/// local_store(cvt(x), dst) -> local_store(x, dst).
struct CanonicalizeConvertFromLocalStore
    : public mlir::OpRewritePattern<triton::gpu::LocalStoreOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(triton::gpu::LocalStoreOp op,
                  PatternRewriter &rewriter) const override {
    if (auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>()) {
      rewriter.replaceOpWithNewOp<triton::gpu::LocalStoreOp>(
          op, convert.getSrc(), op.getDst());
      return success();
    }
    return failure();
  }
};
/// split(cvt(x)) -> split(x).
///
/// Applied only when the destination encoding inferred from the convert's
/// source encoding equals the encoding of the split's existing LHS result.
struct CanonicalizeConvertFromSplit
    : public mlir::OpRewritePattern<triton::SplitOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(triton::SplitOp op,
                  PatternRewriter &rewriter) const override {
    auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
    if (!convert)
      return failure();

    auto preConvertEncoding = convert.getSrc().getType().getEncoding();
    auto inferredEncoding = inferDstEncoding(op, preConvertEncoding);
    auto currentEncoding = op.getOutLHS().getType().getEncoding();
    if (inferredEncoding != currentEncoding)
      return failure();

    rewriter.replaceOpWithNewOp<triton::SplitOp>(op, convert.getSrc());
    return success();
  }
};
/// Canonicalizations rooted at convert_layout itself. Depending on the op that
/// defines the convert's source, the convert is either removed or folded into
/// a re-created producer that directly yields the convert's result type.
struct CanonicalizeConvertFromConvert
    : public OpRewritePattern<ConvertLayoutOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(ConvertLayoutOp op,
                  PatternRewriter &rewriter) const override {
    // A convert whose result types equal its operand types is a no-op.
    if (op->getResultTypes() == op->getOperandTypes()) {
      rewriter.replaceOp(op, op->getOperands());
      return success();
    }

    // Conversions from an NVIDIA MMA encoding to a dot-operand encoding are
    // left untouched.
    auto inputType = op.getSrc().getType();
    auto outputType = op.getType();
    const bool toDotOperand =
        mlir::isa<DotOperandEncodingAttr>(outputType.getEncoding());
    if (toDotOperand &&
        mlir::isa<NvidiaMmaEncodingAttr>(inputType.getEncoding()))
      return failure();

    // Everything below needs a defining op for the source.
    Operation *producer = op.getSrc().getDefiningOp();
    if (!producer)
      return failure();

    // cvt(reshape): replaced by a reshape of the existing reshape's result
    // that produces the convert's result type.
    if (auto reshape = dyn_cast<ReshapeOp>(producer)) {
      if (!reshape.getAllowReorder() || reshape.getEfficientLayout() ||
          isExpensiveView(reshape.getSrc().getType(), op.getType()))
        return failure();
      // Skipped when either side of the convert has a dot-operand encoding.
      if (hasDotOperandEncoding(op->getOperand(0)) ||
          hasDotOperandEncoding(op->getResult(0)))
        return failure();
      rewriter.replaceOpWithNewOp<ReshapeOp>(op, op->getResult(0).getType(),
                                             reshape.getResult(),
                                             reshape.getAllowReorder());
      return success();
    }

    // cvt(histogram) -> histogram with the convert's result type.
    if (auto histogram = dyn_cast<HistogramOp>(producer)) {
      rewriter.replaceOpWithNewOp<HistogramOp>(op, op->getResult(0).getType(),
                                               histogram.getSrc(),
                                               histogram.getMask());
      return success();
    }

    // cvt(local_load) -> local_load. The replacement is created at the
    // position of the original local_load rather than at the convert.
    if (auto sharedLoad = dyn_cast<LocalLoadOp>(producer)) {
      rewriter.setInsertionPoint(producer);
      rewriter.replaceOpWithNewOp<LocalLoadOp>(op, op->getResult(0).getType(),
                                               sharedLoad.getSrc(),
                                               sharedLoad.getToken());
      return success();
    }

    // cvt(cat) -> cat, unless isExpensiveCat rejects the target encoding.
    if (auto cat = dyn_cast<CatOp>(producer)) {
      if (isExpensiveCat(cat, op.getType().getEncoding()))
        return failure();
      rewriter.replaceOpWithNewOp<CatOp>(op, op->getResult(0).getType(),
                                         cat.getOperands());
      return success();
    }

    // cvt(cvt(x, type1), type2) -> cvt(x, type2)
    if (auto innerConvert = dyn_cast<ConvertLayoutOp>(producer)) {
      rewriter.replaceOpWithNewOp<triton::gpu::ConvertLayoutOp>(
          op, op->getResultTypes().front(), innerConvert.getSrc());
      return success();
    }

    // cvt(splat(x)) -> splat(x) with the convert's result type.
    if (auto splat = dyn_cast<triton::SplatOp>(producer)) {
      rewriter.replaceOpWithNewOp<triton::SplatOp>(op, op->getResultTypes(),
                                                   splat.getSrc());
      return success();
    }

    // cvt(make_range) -> make_range with the convert's result type.
    if (auto range = dyn_cast<MakeRangeOp>(producer)) {
      rewriter.replaceOpWithNewOp<MakeRangeOp>(
          op, op->getResultTypes(), range.getStart(), range.getEnd());
      return success();
    }

    // cvt(splat constant) -> splat constant of the convert's result type.
    if (auto constant = llvm::dyn_cast<arith::ConstantOp>(producer)) {
      if (auto splatAttr = dyn_cast<SplatElementsAttr>(constant.getValue())) {
        auto resultType = cast<ShapedType>(op->getResultTypes().front());
        auto retypedAttr = SplatElementsAttr::get(
            resultType, splatAttr.getSplatValue<Attribute>());
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retypedAttr);
        return success();
      }
    }

    return failure();
  }
};
/// Registers every convert_layout-related canonicalization pattern defined in
/// this file. Patterns are added in the same order as they are listed.
void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                  MLIRContext *context) {
  patterns.add<CanonicalizeConvertFromConvert, CanonicalizeConvertFromReshape,
               CanonicalizeConvertFromTranspose,
               CanonicalizeConvertFromGatherSource,
               CanonicalizeConvertFromHistogram, CanonicalizeConvertFromAlloc,
               CanonicalizeConvertFromLocalStore,
               CanonicalizeConvertFromSplit>(context);
}
/// Verifies an fp4-to-fp conversion:
///  * source and result have the same rank;
///  * `axis` lies in [0, rank);
///  * the result element type is bf16 or f16;
///  * the result dimension along `axis` is twice the source dimension, and
///    every other dimension is unchanged.
LogicalResult Fp4ToFpOp::verify() {
  auto srcTy = cast<RankedTensorType>(getSrc().getType());
  auto resTy = cast<RankedTensorType>(getResult().getType());

  auto rank = srcTy.getRank();
  if (rank != resTy.getRank())
    return emitError() << "source rank " << rank << " != result rank "
                       << resTy.getRank();

  auto axis = getAxis();
  if (axis < 0 || axis >= rank)
    return emitError() << "axis " << axis << " out of range for rank " << rank;

  auto elemType = resTy.getElementType();
  if (!elemType.isBF16() && !elemType.isF16())
    return emitError() << "only bf16 or f16 is supported for now, got "
                       << elemType;

  auto srcShape = srcTy.getShape();
  auto resShape = resTy.getShape();
  for (int i = 0; i < rank; ++i) {
    // Dimensions other than `axis` must be preserved exactly.
    if (i != axis) {
      if (resShape[i] != srcShape[i])
        return emitError() << "dimension " << i
                           << " mismatch (src=" << srcShape[i]
                           << ", dst=" << resShape[i] << ", axis=" << axis
                           << ")";
      continue;
    }
    // The result holds two elements per source element along `axis`.
    if (resShape[i] != srcShape[i] * 2)
      return emitError() << "axis " << axis
                         << " dimension must be 2x source dimension (src="
                         << srcShape[i] << ", dst=" << resShape[i] << ")";
  }
  return success();
}
/// Builder that derives the result type from the source.
///
/// The result shape is the source shape with the `axis` dimension doubled, the
/// element type is `elemType`, and the encoding is obtained from the source
/// encoding's dialect via `inferFp4ToFpOpEncoding`. Both the axis range and
/// the success of the encoding inference are checked with `assert` only.
void Fp4ToFpOp::build(OpBuilder &builder, OperationState &state,
                      TypedValue<RankedTensorType> src, Type elemType,
                      int32_t axis) {
  auto srcTy = src.getType();
  auto rank = srcTy.getRank();
  assert(0 <= axis && axis < rank);

  auto resultShape = llvm::to_vector(srcTy.getShape());
  resultShape[axis] *= 2;

  Attribute inEnc = srcTy.getEncoding();
  Attribute outEnc;
  auto *inferLayout =
      inEnc.getDialect()
          .getRegisteredInterface<triton::DialectInferLayoutInterface>();
  auto result = inferLayout->inferFp4ToFpOpEncoding(
      resultShape, axis, inEnc, outEnc, true, state.location);
  assert(succeeded(result));

  auto resultTy = RankedTensorType::get(resultShape, elemType, outEnc);
  build(builder, state, resultTy, src, axis);
}
/// Folds memdesc transposes:
///  * an identity order folds to the source;
///  * trans(trans(x)) is folded in place: this op's order becomes
///    `applyPermutation(inner order, this order)` and its operand becomes the
///    inner transpose's source.
OpFoldResult MemDescTransOp::fold(FoldAdaptor adaptor) {
  if (isIota(getOrder()))
    return getSrc();

  auto innerTrans = getSrc().getDefiningOp<MemDescTransOp>();
  if (!innerTrans)
    return {};

  // In-place update; returning our own result signals the modification.
  setOrder(applyPermutation(innerTrans.getOrder(), getOrder()));
  setOperand(innerTrans.getSrc());
  return getResult();
}
/// Infers the result type of a memdesc transpose.
///
/// The shape is the source shape permuted by `order`. If the source has an
/// encoding, the result encoding comes from the encoding dialect's
/// `inferTransOpEncoding`; otherwise it stays null. The trailing
/// `order.size()` dimensions of the alloc shape are permuted the same way and
/// any leading alloc dimensions are kept in front unchanged. Element type,
/// memory space and mutability are copied from the source.
LogicalResult
MemDescTransOp::inferReturnTypes(MLIRContext *context,
                                 std::optional<Location> loc,
                                 MemDescTransOp::Adaptor adaptor,
                                 SmallVectorImpl<Type> &inferredReturnTypes) {
  auto srcTy = cast<MemDescType>(adaptor.getSrc().getType());
  auto srcShape = srcTy.getShape();
  auto order = adaptor.getOrder();

  Attribute resultEncoding;
  if (Attribute srcEncoding = srcTy.getEncoding()) {
    auto *inferLayout =
        cast<DialectInferLayoutInterface>(&srcEncoding.getDialect());
    if (failed(inferLayout->inferTransOpEncoding(srcEncoding, srcShape, order,
                                                 resultEncoding, loc)))
      return failure();
  }

  SmallVector<int64_t> resultShape = applyPermutation(srcShape, order);

  // Permute the trailing dims of the alloc shape, then re-attach the
  // untouched leading dims.
  auto srcAllocShape = srcTy.getAllocShape();
  SmallVector<int64_t> resultAllocShape =
      applyPermutation(srcAllocShape.take_back(order.size()), order);
  resultAllocShape.insert(resultAllocShape.begin(), srcAllocShape.begin(),
                          srcAllocShape.end() - order.size());

  inferredReturnTypes.push_back(MemDescType::get(
      resultShape, srcTy.getElementType(), resultEncoding,
      srcTy.getMemorySpace(), srcTy.getMutableMemory(), resultAllocShape));
  return success();
}
/// Verifies a memdesc reshape:
///  * source and result hold the same number of elements;
///  * element types match;
///  * the trailing dims of the source alloc shape equal the source shape;
///  * the declared result type equals the type computed by inferReturnTypes
///    (alloc shape is compared first so it gets its own diagnostic).
LogicalResult MemDescReshapeOp::verify() {
  MemDescType srcType = getSrc().getType();
  MemDescType dstType = getResult().getType();

  if (product(dstType.getShape()) != product(srcType.getShape()))
    return emitError(
        "number of src and dst elements of reshape must be the same");

  if (dstType.getElementType() != srcType.getElementType())
    return emitError("result element type must match src element type");

  auto srcShape = srcType.getShape();
  if (srcType.getAllocShape().take_back(srcShape.size()) != srcShape)
    return emitError("NYI: memdesc_reshape of memdesc_subslice");

  MemDescType inferredTy;
  if (failed(inferReturnTypes(getContext(), getLoc(), srcType,
                              dstType.getShape(), inferredTy)))
    return failure();

  if (inferredTy.getAllocShape() != dstType.getAllocShape())
    return emitError(
        "The result alloc shape does not match the expected alloc shape.");

  if (inferredTy != dstType)
    return emitError("source and destination layout are incompatible.");

  return success();
}
/// Infers the encoding of a reshaped memdesc.
///
/// Only NVMMASharedEncodingAttr sources are handled; every other encoding
/// fails. For an NVMMA shared encoding the inference fails when
///  * the encoding spans more than one CTA,
///  * the inner dimension (front if transposed, back otherwise) differs
///    between source and destination shape, or
///  * reshaping the source linear layout to `dstShape` does not give the
///    linear layout of the candidate destination encoding.
/// On success `dstEnc` holds an NVMMA shared encoding with the source's
/// swizzling/transposed/bit-width/fp4-padded parameters and a CTA layout built
/// from all-ones vectors and the order 0..rank-1.
static LogicalResult inferMemDescReshapeOpEncoding(ArrayRef<int64_t> srcShape,
                                                   Attribute srcEnc,
                                                   ArrayRef<int64_t> dstShape,
                                                   Attribute &dstEnc) {
  auto mmaEncoding = dyn_cast<NVMMASharedEncodingAttr>(srcEnc);
  if (!mmaEncoding)
    return failure();

  if (getNumCTAs(mmaEncoding) > 1)
    return failure();

  const bool transposed = mmaEncoding.getTransposed();
  int innerDimDst = transposed ? dstShape.front() : dstShape.back();
  int innerDimSrc = transposed ? srcShape.front() : srcShape.back();
  if (innerDimDst != innerDimSrc)
    return failure();

  auto *ctx = srcEnc.getContext();
  const size_t dstRank = dstShape.size();
  auto CTALayout =
      CTALayoutAttr::get(ctx, SmallVector<unsigned>(dstRank, 1),
                         SmallVector<unsigned>(dstRank, 1),
                         llvm::to_vector(llvm::seq<unsigned>(dstRank)));
  dstEnc = NVMMASharedEncodingAttr::get(
      ctx, mmaEncoding.getSwizzlingByteWidth(), mmaEncoding.getTransposed(),
      mmaEncoding.getElementBitWidth(), mmaEncoding.getFp4Padded(), CTALayout);

  // Note: `dstEnc` has already been written when this final check fails.
  auto srcLL = toLinearLayout(srcShape, srcEnc);
  auto dstLL = toLinearLayout(dstShape, dstEnc);
  if (reshapeLayout(ctx, srcLL, dstShape) != dstLL)
    return failure();
  return success();
}
/// Computes the result type of reshaping `srcTy` to `dstShape`.
///
/// Fails (emitting an error when `loc` is provided) if the element counts
/// differ, or silently if the encoding cannot be inferred. The result alloc
/// shape is the leading alloc dims of the source that are not part of its
/// shape, followed by `dstShape`. A source without an encoding yields a result
/// without an encoding.
LogicalResult MemDescReshapeOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> loc, MemDescType srcTy,
    ArrayRef<int64_t> dstShape, MemDescType &inferredReturnType) {
  auto srcShape = srcTy.getShape();
  if (product<int64_t>(dstShape) != product<int64_t>(srcShape))
    return emitOptionalError(
        loc, "dst shape has different number of elements than src");

  Attribute dstEncoding;
  Attribute srcEnc = srcTy.getEncoding();
  if (srcEnc && failed(inferMemDescReshapeOpEncoding(srcShape, srcEnc,
                                                     dstShape, dstEncoding)))
    return failure();

  auto srcAllocShape = srcTy.getAllocShape();
  const size_t numLeadingDims = srcAllocShape.size() - srcShape.size();
  SmallVector<int64_t> dstAllocShape =
      llvm::to_vector(srcAllocShape.take_front(numLeadingDims));
  dstAllocShape.append(dstShape.begin(), dstShape.end());

  inferredReturnType = MemDescType::get(
      dstShape, srcTy.getElementType(), dstEncoding, srcTy.getMemorySpace(),
      srcTy.getMutableMemory(), dstAllocShape);
  return success();
}
/// A reinterpret whose result type equals its source type folds to the source.
OpFoldResult MemDescReinterpretOp::fold(FoldAdaptor adaptor) {
  if (getType() != getSrc().getType())
    return {};
  return getSrc();
}
/// Reports the memory effects of a local allocation.
///
/// An allocation with an immutable result type and no `allocation.offset`
/// attribute reports no effects at all. Otherwise the op allocates its result
/// in SharedMemory and, when it has a source operand, also writes to it.
void LocalAllocOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  Operation *op = getOperation();
  const bool isMutable = getType().getMutableMemory();
  if (!isMutable && !op->hasAttr("allocation.offset"))
    return;

  OpResult buffer = op->getOpResult(0);
  effects.emplace_back(MemoryEffects::Allocate::get(), buffer,
                       SharedMemory::get());
  if (getSrc()) {
    effects.emplace_back(MemoryEffects::Write::get(), buffer,
                         SharedMemory::get());
  }
}
/// local_alloc(local_load(m)) -> m.
///
/// Applies only to immutable allocations that have a source defined by a
/// local_load whose memdesc has exactly the allocation's result type.
OpFoldResult LocalAllocOp::fold(FoldAdaptor adaptor) {
  if (getType().getMutableMemory())
    return {};

  auto initValue = getSrc();
  if (!initValue)
    return {};

  if (auto load = initValue.getDefiningOp<LocalLoadOp>()) {
    auto loadedFrom = load.getSrc();
    if (loadedFrom.getType() == getType())
      return loadedFrom;
  }
  return {};
}
/// Returns the alignment to use for this allocation.
///
/// Resolution order:
///  1. the op's explicit alignment, when set;
///  2. the alignment reported by the result type's encoding, when that
///     encoding implements SharedEncodingTrait;
///  3. 16 otherwise, including when the result type carries no encoding.
int32_t LocalAllocOp::getAlignmentOrDefault() {
  auto align = getAlignment();
  if (align) {
    return *align;
  }

  auto ty = getType();
  // `dyn_cast` must not be used on a null attribute; the `_if_present` form
  // lets a missing encoding reach the 16 fallback below instead of asserting.
  auto enc = llvm::dyn_cast_if_present<SharedEncodingTrait>(ty.getEncoding());
  return enc ? enc.getAlignment() : 16;
}
/// Checks that `srcTy` and `dstTy` agree on element type and shape.
///
/// Emits an op error on `op` and returns failure on the first mismatch;
/// returns success when both match.
LogicalResult verifyMemoryOpTypes(Operation *op, ShapedType srcTy,
                                  ShapedType dstTy) {
  if (srcTy.getElementType() != dstTy.getElementType()) {
    // Report the element type on both sides; printing the whole source type
    // here would contradict the "element type" wording of the message.
    return op->emitOpError("source element type ")
           << srcTy.getElementType()
           << " must match destination element type "
           << dstTy.getElementType();
  }
  if (srcTy.getShape() != dstTy.getShape()) {
    return op->emitOpError("source shape [")
           << srcTy.getShape() << "] must match destination shape ["
           << dstTy.getShape() << "]";
  }
  return success();
}
/// Shared verification for allocation ops producing `dstTy`.
///
///  * The result shape must equal its alloc shape.
///  * Without a source, the result memdesc must be mutable.
///  * With a source, its element type and shape must match the result
///    (see verifyMemoryOpTypes); the source is cast to RankedTensorType.
LogicalResult verifyAllocOp(Operation *op, Value src, MemDescType dstTy) {
  if (dstTy.getShape() != dstTy.getAllocShape())
    return op->emitOpError("result shape and its alloc shape must match");

  if (src)
    return verifyMemoryOpTypes(op, cast<RankedTensorType>(src.getType()),
                               dstTy);

  if (!dstTy.getMutableMemory())
    return op->emitOpError(
        "uninitialized alloc must have a mutable memdesc type");
  return success();
}
/// Checks that the rank of the tensor `type` equals the rank reported by the
/// encoding of `memdesc`.
///
/// `regName` names the tensor in the diagnostic (callers in this file pass
/// "source" or "result"). Fails with an op error when the memdesc has no
/// encoding, when its encoding does not implement LayoutEncodingTrait, or
/// when the ranks differ.
static LogicalResult verifySharedMemoryRank(Operation *op,
                                            RankedTensorType type,
                                            MemDescType memdesc,
                                            StringRef regName) {
  // Use the null-tolerant cast: a memdesc without an encoding must produce
  // the diagnostic below rather than trip the `dyn_cast` non-null assertion.
  auto enc =
      llvm::dyn_cast_if_present<LayoutEncodingTrait>(memdesc.getEncoding());
  if (!enc)
    return op->emitOpError("expected memdesc to have a shared memory encoding");
  if (type.getRank() != enc.getRank()) {
    return op->emitOpError(regName)
           << " has rank " << type.getRank()
           << " but memdesc encoding has rank " << enc.getRank();
  }
  return success();
}
/// Verifies a local allocation: the result must live in the shared memory
/// space, an optional source must match the rank of the result encoding, and
/// the common allocation checks in verifyAllocOp must pass.
LogicalResult LocalAllocOp::verify() {
  auto resultTy = getType();
  if (!isa<SharedMemorySpaceAttr>(resultTy.getMemorySpace()))
    return emitOpError("should create a buffer of shared memory");

  if (getSrc()) {
    if (failed(verifySharedMemoryRank(*this, getSrc().getType(), getType(),
                                      "source")))
      return failure();
  }
  return verifyAllocOp(*this, getSrc(), getType());
}
/// Verifies a local store: the destination must be mutable, the source rank
/// must match the destination encoding's rank, and element type and shape
/// must agree between source and destination.
LogicalResult LocalStoreOp::verify() {
  auto dstTy = getDst().getType();
  if (!dstTy.getMutableMemory())
    return emitOpError("Cannot store into immutable memory");

  auto srcTy = getSrc().getType();
  if (failed(verifySharedMemoryRank(*this, srcTy, dstTy, "source")))
    return failure();
  return verifyMemoryOpTypes(*this, srcTy, dstTy);
}
/// Verifies a local load: the result rank must match the source encoding's
/// rank, and element type and shape must agree between source and result.
LogicalResult LocalLoadOp::verify() {
  auto srcTy = getSrc().getType();
  auto resultTy = getType();
  if (failed(verifySharedMemoryRank(*this, resultTy, srcTy, "result")))
    return failure();
  return verifyMemoryOpTypes(*this, srcTy, resultTy);
}
/// Verifies that the memdesc written by the async copy is mutable.
LogicalResult AsyncCopyGlobalToLocalOp::verify() {
  const bool writable = getResult().getType().getMutableMemory();
  if (writable)
    return success();
  return emitOpError("Cannot store into immutable memory");
}
/// Verifies a memdesc_index op.
///
/// Checks, in order:
///  * element types of source and result match;
///  * the result rank is the source rank minus one;
///  * the source alloc shape has as many dims as the source shape;
///  * the result shape equals the trailing dims of the source shape;
///  * the source is not a subview (alloc shape equals shape);
///  * source and result either both have an encoding or both have none;
///  * both or neither encoding implements SharedEncodingTrait;
///  * for a TensorMemoryEncodingAttr source, only 3D -> 2D is allowed.
LogicalResult MemDescIndexOp::verify() {
  auto srcTy = getSrc().getType();
  auto dstTy = getType();
  if (srcTy.getElementType() != dstTy.getElementType()) {
    return emitError("result element type must match desc element type");
  }
  bool correctRank = srcTy.getRank() == dstTy.getRank() + 1;
  if (!correctRank) {
    return emitError("result rank must be input rank - 1");
  }
  if (srcTy.getAllocShape().size() != srcTy.getRank()) {
    return emitError("We don't allow taking memdesc_index of a memdesc_index");
  }
  if (ArrayRef(srcTy.getShape()).take_back(dstTy.getRank()) !=
      dstTy.getShape()) {
    return emitError("result shape must equal to srcShape[1:]");
  }
  bool isSubview = srcTy.getAllocShape() != srcTy.getShape();
  if (isSubview) {
    return emitError("We don't support memdesc_index of a subview");
  }

  auto srcEnc = srcTy.getEncoding();
  auto dstEnc = dstTy.getEncoding();
  if (bool(srcEnc) != bool(dstEnc)) {
    return emitError("src and result must both have or not have an encoding");
  }

  // Both encodings may legitimately be null at this point, and `isa` must not
  // be applied to a null attribute, so test for presence first.
  auto isSharedEncoding = [](Attribute enc) {
    return enc && isa<SharedEncodingTrait>(enc);
  };
  if (isSharedEncoding(srcEnc) != isSharedEncoding(dstEnc)) {
    return emitError("src and dst must have the same type of encoding");
  }

  if (srcEnc && isa<triton::nvidia_gpu::TensorMemoryEncodingAttr>(srcEnc)) {
    if (srcTy.getRank() != 3 || dstTy.getRank() != 2) {
      return emitError("only 3D -> 2D subviews are supported for "
                       "TensorMemoryEncodingAttr");
    }
  }
  return success();
}
/// Verifies a memdesc_subslice op.
///
/// Checks, in order:
///  * element types match, there is one offset per source dim, and source and
///    result have the same rank;
///  * source and result both carry an encoding implementing
///    SharedEncodingTrait;
///  * if no dimension changes size, the op is accepted as is;
///  * offsets are zero in dimensions that are not split, and pass the
///    `offset & (result dim size - 1)` test in dimensions that are split;
///  * the source linear layout has a trivial "block" input dimension;
///  * for every split dimension, each doubling step from the result size up
///    to the source size maps, through the inverted layout, to a power of two
///    in the first output dimension.
LogicalResult MemDescSubsliceOp::verify() {
  auto srcTy = getSrc().getType();
  auto dstTy = getType();
  if (srcTy.getElementType() != dstTy.getElementType()) {
    return emitError("result element type must match desc element type");
  }
  if (getOffsets().size() != srcTy.getRank()) {
    return emitError("offsets must have the same rank as input");
  }
  if (srcTy.getRank() != dstTy.getRank()) {
    return emitError("result rank must equal to input rank");
  }

  auto srcEnc = srcTy.getEncoding();
  auto dstEnc = dstTy.getEncoding();
  if (bool(srcEnc) != bool(dstEnc)) {
    return emitError("src and result must both have or not have an encoding");
  }
  // Both encodings can still be null here; `isa` on a null attribute is
  // invalid, so a missing encoding is rejected explicitly with the same
  // diagnostic.
  if (!srcEnc || !dstEnc || !isa<SharedEncodingTrait>(srcEnc) ||
      !isa<SharedEncodingTrait>(dstEnc)) {
    return emitError("src and dst must both be of shared memory encoding");
  }

  // Dimensions whose size differs between source and result.
  SetVector<int> splitDims{};
  for (int i = 0; i < srcTy.getRank(); i++) {
    if (srcTy.getDimSize(i) != dstTy.getDimSize(i)) {
      splitDims.insert(i);
    }
  }
  SmallVector<int64_t> offsets(getOffsets().begin(), getOffsets().end());
  // Nothing is split: identity subslice.
  if (splitDims.empty()) {
    return success();
  }

  for (auto [dim, offset] : llvm::enumerate(offsets)) {
    if (!splitDims.contains(dim)) {
      if (offset != 0) {
        return emitError("A non zero offset found in a dimension that is "
                         "not being split");
      }
    } else {
      if (offset & (dstTy.getDimSize(dim) - 1)) {
        return emitError("The split offset may not touch the tile");
      }
    }
  }

  auto ctx = getContext();
  auto ll = triton::gpu::toLinearLayout(srcTy);
  auto kBlock = mlir::StringAttr::get(getContext(), "block");
  if (ll.getInDimSize(kBlock) != 1) {
    return emitError("non-trivial block dimension not supported");
  }

  auto llInv = ll.invert();
  for (auto dim : splitDims) {
    auto kDim = mlir::StringAttr::get(ctx, "dim" + llvm::Twine(dim));
    // Start from an all-zero coordinate and vary only `dim`.
    llvm::SmallVector<std::pair<mlir::StringAttr, int32_t>> namedOffsets;
    for (auto d : standardOutDimNames(ctx, srcTy.getRank())) {
      namedOffsets.push_back({d, 0});
    }
    for (int dimSize = dstTy.getDimSize(dim); dimSize < srcTy.getDimSize(dim);
         dimSize *= 2) {
      namedOffsets[dim] = {kDim, dimSize};
      if (!llvm::isPowerOf2_32(llInv.apply(namedOffsets)[0].second)) {
        return emitError(
            "We don't support splitting along the swizzling pattern");
      }
    }
  }
  return success();
}
/// Returns the partition regions, which are owned by the
/// WarpSpecializePartitionsOp that is the first op of the first block of the
/// partition-holder region.
RegionRange WarpSpecializeOp::getPartitionRegions() {
  Operation &holderOp = getPartitionOpHolder().front().front();
  return cast<WarpSpecializePartitionsOp>(holderOp).getPartitionRegions();
}
/// Region control flow: the parent op branches into the default region, and
/// the default region branches back to the parent's results. Any other branch
/// point is rejected by an assertion.
void WarpSpecializeOp::getSuccessorRegions(
    RegionBranchPoint src, SmallVectorImpl<RegionSuccessor> &successors) {
  if (!src.isParent()) {
    assert(src.getRegionOrNull() == &getDefaultRegion());
    successors.push_back(RegionSuccessor(getResults()));
    return;
  }
  successors.emplace_back(&getDefaultRegion());
}
/// Verifies a warp_specialize op:
///  * the first op of the partition-holder region is a
///    WarpSpecializePartitionsOp;
///  * there is one `partitionNumWarps` entry per partition region, and each
///    entry is a power of two;
///  * if warp group start IDs are present, there is one per partition;
///  * every partition region has one argument per operand, with types equal
///    to the explicit capture types;
///  * the op is not nested inside another warp_specialize op;
///  * when `maybeLookupNumWarps` yields a value, it is a multiple of 4.
LogicalResult WarpSpecializeOp::verify() {
  Operation &holderOp = getPartitionOpHolder().front().front();
  if (!isa<WarpSpecializePartitionsOp>(holderOp)) {
    return emitOpError(
        "expected to find only a `ttg.warp_specialize.partitions` op inside "
        "its second region");
  }

  // Partition count and per-partition warp counts.
  if (getPartitionRegions().size() != getPartitionNumWarps().size()) {
    return emitOpError("has ") << getPartitionRegions().size()
                               << " partitions but `partitionNumWarps` has "
                               << getPartitionNumWarps().size() << " elements";
  }
  for (auto [i, numWarps] : llvm::enumerate(getPartitionNumWarps())) {
    if (!llvm::isPowerOf2_32(numWarps)) {
      return emitOpError("partition #")
             << i << " number of warps (" << numWarps
             << ") must be a power of 2";
    }
  }

  if (std::optional<ArrayRef<int32_t>> startIds = getWarpGroupStartIds()) {
    if (startIds->size() != getPartitionNumWarps().size()) {
      return emitOpError("has ")
             << startIds->size() << " warp group start IDs but expected "
             << getPartitionNumWarps().size();
    }
  }

  // Partition region signatures must mirror the captures.
  for (auto [i, region] : llvm::enumerate(getPartitionRegions())) {
    if (region->getNumArguments() != getNumOperands()) {
      return emitOpError("partition region #")
             << i << " has " << region->getNumArguments()
             << " arguments but expected " << getNumOperands();
    }
    for (auto [argIdx, argType, capType] : llvm::enumerate(
             region->getArgumentTypes(), getExplicitCaptures().getTypes())) {
      if (argType != capType) {
        return emitOpError("partition region #")
               << i << " argument #" << argIdx << " has type " << argType
               << " but corresponding capture has type " << capType;
      }
    }
  }

  if ((*this)->getParentOfType<WarpSpecializeOp>()) {
    return emitOpError(
        "cannot be nested inside another `ttg.warp_specialize` op");
  }

  std::optional<int> numWarps = maybeLookupNumWarps(*this);
  if (numWarps && *numWarps % 4 != 0) {
    return mlir::emitError(getLoc()) << "warp-specialized kernels requires "
                                        "num_warps to be a multiple of 4";
  }
  return success();
}
/// Canonicalizes a warp_specialize op by dropping dead captures and results.
///
/// A capture is removed when no partition region uses its block argument, or
/// when it duplicates an earlier capture (uses are first redirected to the
/// earlier argument). A result is removed when it has no uses; because the
/// result list changes, the op is then re-created with the remaining result
/// types and the original op is erased. Returns failure when there is nothing
/// to remove.
LogicalResult WarpSpecializeOp::canonicalize(WarpSpecializeOp op,
PatternRewriter &b) {
// One bit per operand / result; a set bit marks it for removal.
llvm::BitVector unusedArgs(op.getNumOperands());
llvm::BitVector unusedResults(op.getNumResults());
for (auto [i, result] : llvm::enumerate(op.getResults())) {
if (result.use_empty())
unusedResults.set(i);
}
// Maps each distinct captured value to the index of its first used capture.
DenseMap<Value, unsigned> uniqueCaptures;
for (auto [i, capture] : llvm::enumerate(op.getExplicitCaptures())) {
auto noUseInRegion = [i = i](Region *region) {
return region->getArgument(i).use_empty();
};
// Captures unused by every partition are dropped and do not take part in
// duplicate detection.
if (llvm::all_of(op.getPartitionRegions(), noUseInRegion)) {
unusedArgs.set(i);
continue;
}
auto [it, inserted] = uniqueCaptures.try_emplace(capture, i);
if (!inserted) {
// Duplicate capture: forward all uses of argument `i` to the argument of
// the first occurrence, then mark `i` for removal.
unsigned duplicateIdx = it->second;
b.modifyOpInPlace(op, [&, i = i] {
for (Region *region : op.getPartitionRegions()) {
b.replaceAllUsesWith(region->getArgument(i),
region->getArgument(duplicateIdx));
}
});
unusedArgs.set(i);
}
}
if (unusedArgs.none() && unusedResults.none())
return failure();
if (unusedArgs.any()) {
// Erase the marked block arguments in every partition and the matching
// operands of the op.
b.modifyOpInPlace(op, [&] {
for (Region *region : op.getPartitionRegions())
region->front().eraseArguments(unusedArgs);
op->eraseOperands(unusedArgs);
});
}
if (unusedResults.any()) {
// Drop the corresponding operands from every WarpYieldOp terminator in the
// default region.
for (Block &block : op.getDefaultRegion()) {
if (auto yield = dyn_cast<WarpYieldOp>(block.getTerminator())) {
b.modifyOpInPlace(yield, [&] { yield->eraseOperands(unusedResults); });
}
}
SmallVector<Type> newTypes;
for (auto [i, type] : llvm::enumerate(op.getResultTypes())) {
if (!unusedResults.test(i))
newTypes.push_back(type);
}
// Re-create the op with the reduced result list, moving both region bodies
// over from the original op.
OperationState state(op.getLoc(), op->getName(), op.getOperands(), newTypes,
op->getAttrs());
state.addRegion()->takeBody(op.getDefaultRegion());
state.addRegion()->takeBody(op.getPartitionOpHolder());
auto newOp = cast<WarpSpecializeOp>(b.create(state));
// Remap the surviving results onto the new op, in order.
unsigned newResultIdx = 0;
for (auto [i, result] : llvm::enumerate(op.getResults())) {
if (!unusedResults.test(i))
result.replaceAllUsesWith(newOp.getResult(newResultIdx++));
}
assert(newResultIdx == newOp.getNumResults());
b.eraseOp(op);
}
return success();
}
/// Builds a warp_specialize op without explicit captures and populates its
/// partition-holder region.
///
/// After delegating to the full builder with an empty capture list, a block is
/// created in the last region of `state` and a WarpSpecializePartitionsOp with
/// `partitionNumRegions` regions is created inside it. The builder's insertion
/// point is restored on return.
void WarpSpecializeOp::build(OpBuilder &builder, OperationState &state,
                             TypeRange resultTypes,
                             ArrayRef<int32_t> partitionNumWarps,
                             unsigned partitionNumRegions) {
  build(builder, state, resultTypes, /*explicitCaptures=*/ValueRange(),
        partitionNumWarps, {}, {}, {});
  OpBuilder::InsertionGuard guard(builder);
  // createBlock also moves the insertion point into the new block; the
  // returned pointer itself is not needed.
  builder.createBlock(state.regions.back().get());
  builder.create<WarpSpecializePartitionsOp>(state.location,
                                             partitionNumRegions);
}
/// Builds a warp_specialize op from result types, explicit captures and
/// per-partition warp counts, passing empty values for the three remaining
/// arguments of the full builder.
void WarpSpecializeOp::build(OpBuilder &builder, OperationState &state,
                             TypeRange resultTypes, ValueRange explicitCaptures,
                             ArrayRef<int32_t> partitionNumWarps) {
  build(builder, state, resultTypes, explicitCaptures, partitionNumWarps,
        {}, {}, {});
}
/// Parses the custom assembly form of warp_specialize:
///
///   `(` operands `)` attr-dict-with-keyword?
///   `default` region
///   (`partitionN` `(` args `)` `num_warps` `(` int `)` region)*
///   `:` function-type
///
/// Partition keywords must be numbered consecutively from 0. The partition
/// regions are attached to a WarpSpecializePartitionsOp that is created inside
/// a fresh block of the op's second region, and the parsed warp counts become
/// the `partitionNumWarps` attribute.
ParseResult WarpSpecializeOp::parse(OpAsmParser &p, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand> operands;
  SMLoc operandLoc = p.getCurrentLocation();
  if (p.parseOperandList(operands, AsmParser::Delimiter::Paren) ||
      p.parseOptionalAttrDictWithKeyword(result.attributes) ||
      p.parseKeyword("default") || p.parseRegion(*result.addRegion()))
    return failure();

  // State of the nested op that will own the partition regions.
  OperationState partitionOpState(
      p.getEncodedSourceLoc(p.getCurrentLocation()),
      WarpSpecializePartitionsOp::getOperationName());

  SmallVector<int32_t> partitionNumWarps;
  SmallVector<OpAsmParser::Argument> partitionArgs;
  // The expected keyword index is the number of partitions parsed so far.
  while (succeeded(p.parseOptionalKeyword(
      ("partition" + Twine(partitionNumWarps.size()).str())))) {
    partitionArgs.clear();
    if (p.parseArgumentList(partitionArgs, AsmParser::Delimiter::Paren,
                            /*allowType=*/true) ||
        p.parseKeyword("num_warps") || p.parseLParen() ||
        p.parseInteger(partitionNumWarps.emplace_back()) || p.parseRParen() ||
        p.parseRegion(*partitionOpState.addRegion(), partitionArgs))
      return failure();
  }

  FunctionType types;
  if (p.parseColon() || p.parseType(types) ||
      p.resolveOperands(operands, types.getInputs(), operandLoc,
                        result.operands))
    return failure();

  result.addTypes(types.getResults());
  result.addAttribute(getPartitionNumWarpsAttrName(result.name),
                      p.getBuilder().getDenseI32ArrayAttr(partitionNumWarps));

  // Materialize the holder block and the partitions op inside it.
  Block &holder = result.addRegion()->emplaceBlock();
  OpBuilder b(p.getContext());
  b.setInsertionPointToStart(&holder);
  b.create(partitionOpState);
  return success();
}
/// Prints the custom assembly form: the parenthesized operands, the attribute
/// dictionary without `partitionNumWarps`, the `default` region, one
/// `partitionN(args) num_warps(N)` region per partition, and finally the
/// functional type of the op.
void WarpSpecializeOp::print(OpAsmPrinter &p) {
  p << '(';
  p.printOperands(getOperands());
  p << ')';
  p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs(),
                                     {getPartitionNumWarpsAttrName()});

  p.printNewline();
  p << "default ";
  p.printRegion(getDefaultRegion(), /*printEntryBlockArgs=*/false);

  for (auto [index, region, warps] :
       llvm::enumerate(getPartitionRegions(), getPartitionNumWarps())) {
    p.printNewline();
    p << "partition" << index << '(';
    // The region arguments are printed here, so printRegion must not repeat
    // them.
    llvm::interleaveComma(region->getArguments(), p, [&](BlockArgument arg) {
      p.printRegionArgument(arg);
    });
    p << ") num_warps(" << warps << ") ";
    p.printRegion(*region, /*printEntryBlockArgs=*/false);
  }

  p << " : ";
  p.printFunctionalType(*this);
}
/// Verifies that the yielded operands match the parent op's results in both
/// count and type.
LogicalResult WarpYieldOp::verify() {
  if (getNumOperands() != getParentOp().getNumResults()) {
    return emitOpError("has ")
           << getNumOperands() << " operands but parent op expected "
           << getParentOp().getNumResults();
  }

  for (auto [idx, expectedType, operandType] :
       llvm::enumerate(getParentOp().getResultTypes(), getOperandTypes())) {
    if (expectedType == operandType)
      continue;
    return emitOpError("operand #")
           << idx << " has type " << operandType
           << " but parent op expected " << expectedType;
  }
  return success();
}
/// Returns the size in bytes accounted for a captured value of `type`:
///  * integer/float types: their bit width rounded up to whole bytes;
///  * pointer and tensor-descriptor types: 8;
///  * memdesc types: 8 outside the shared memory space, otherwise
///    8 + 4 * rank.
/// Any other type is a fatal error.
static size_t getSharedMemorySize(Type type) {
  if (isa<IntegerType, FloatType>(type))
    return llvm::divideCeil(type.getIntOrFloatBitWidth(), 8);

  if (isa<PointerType, TensorDescType>(type))
    return 8;

  auto desc = dyn_cast<MemDescType>(type);
  if (!desc) {
    llvm::report_fatal_error(
        Twine("shared memory size for scalar type is unspecified: ") +
        mlir::debugString(type));
  }
  if (!isa<SharedMemorySpaceAttr>(desc.getMemorySpace()))
    return 8;
  return 8 + desc.getRank() * 4;
}
/// Returns {total size, alignment} of the captures: the size is the sum of
/// getSharedMemorySize over all operand types and the alignment is always 8.
std::pair<uint64_t, uint64_t> WarpSpecializeOp::getCaptureSizeAlign() {
  uint64_t totalBytes = 0;
  for (Type operandType : getOperandTypes())
    totalBytes += getSharedMemorySize(operandType);
  return {totalBytes, 8};
}
unsigned WarpSpecializeOp::getTotalPartitionWarps() {
ArrayRef<int32_t> numWarps = getPartitionNumWarps();
return std::accumulate(numWarps.begin(), numWarps.end(), 0);
}
}