#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
using namespace mlir;
using namespace mlir::triton;
using namespace mlir::triton::gpu;
namespace {
// Conversion pattern that lowers `GatherOp` to LLVM. Two lowering strategies
// are declared here; `matchAndRewrite` selects one of them per op.
class GatherOpConversion : public ConvertOpToLLVMPattern<GatherOp> {
public:
  GatherOpConversion(LLVMTypeConverter &typeConverter,
                     const TargetInfoBase &targetInfo, PatternBenefit benefit)
      : ConvertOpToLLVMPattern<GatherOp>(typeConverter, benefit),
        targetInfo(targetInfo) {}

  // Rewrites a single `GatherOp` using one of the strategies below.
  LogicalResult
  matchAndRewrite(GatherOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

private:
  // Shared-memory strategy: the source tensor is written to shared memory and
  // every output element is then read back from there.
  void emitGatherInShared(GatherOp op, OpAdaptor adaptor,
                          ConversionPatternRewriter &rewriter) const;

  // Warp-local strategy: elements are exchanged between the lanes of a warp
  // with shuffles, and the wanted one is picked with selects.
  void emitWarpLocalGather(GatherOp op, OpAdaptor adaptor,
                           ConversionPatternRewriter &rewriter) const;

  // Target hooks handed to the lowering helpers.
  const TargetInfoBase &targetInfo;
};
LogicalResult
GatherOpConversion::matchAndRewrite(GatherOp op, OpAdaptor adaptor,
                                    ConversionPatternRewriter &rewriter) const {
  GatherLoweringHelper helper(op);
  // The lowering is specialized on the source layout. A warp shuffle costs
  // roughly half of a conflict-free round trip through shared memory, so
  // choosing between the two paths will eventually need a more precise
  // heuristic; for now we rely on the middle end to pick the right layout.
  if (!helper.isWarpLocal()) {
    emitGatherInShared(op, adaptor, rewriter);
    return success();
  }
  emitWarpLocalGather(op, adaptor, rewriter);
  return success();
}
// Returns `index` as a 32-bit integer value. The linear-layout index
// computations are done in i32, so wider indices are truncated and narrower
// ones are extended; an index that is already 32 bits wide is returned as is.
static Value convertIndexToI32(Location loc, Value index,
                               ConversionPatternRewriter &rewriter) {
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  const unsigned bitWidth = index.getType().getIntOrFloatBitWidth();
  if (bitWidth == 32)
    return index;
  if (bitWidth > 32) {
    Value narrowed = b.trunc(i32_ty, index);
    return narrowed;
  }
  // A negative index is meaningless here, so zero-extension is used.
  Value widened = b.zext(i32_ty, index);
  return widened;
}
// Lowers the gather through scratch shared memory: every thread first writes
// the source elements it owns, the CTA synchronizes, and then every thread
// reads the elements selected by its index values.
void GatherOpConversion::emitGatherInShared(
    GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  auto b = TritonLLVMOpBuilder(loc, rewriter);
  RankedTensorType srcTy = op.getSrc().getType();

  // Shape of the part of the source tensor owned by this CTA; it is the shape
  // used to linearize coordinates into scratch-memory offsets.
  SmallVector<unsigned> ctaShape =
      convertType<unsigned>(triton::gpu::getShapePerCTA(srcTy));

  // Source values held by this thread, and the coordinates of each of them.
  SmallVector<Value> threadSrcVals =
      unpackLLElements(loc, adaptor.getSrc(), rewriter);
  SmallVector<SmallVector<Value>> threadSrcCoords =
      emitIndices(loc, rewriter, targetInfo, srcTy.getEncoding(), srcTy,
                  /*withCTAOffset=*/true);
  assert(threadSrcVals.size() == threadSrcCoords.size());

  // Base pointer of the scratch memory, and the element type stored there.
  Value scratchBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op);
  Type llvmElemTy = getTypeConverter()->convertType(srcTy.getElementType());

  // Turns per-dimension coordinates into a pointer into scratch memory by
  // linearizing them against the per-CTA shape.
  auto scratchSlot = [&](ArrayRef<Value> coords) -> Value {
    Value linear = LLVM::linearize(rewriter, loc, coords, ctaShape);
    Value slot = b.gep(scratchBase.getType(), llvmElemTy, scratchBase, linear);
    return slot;
  };

  // Publish every source element owned by this thread.
  for (auto [srcVal, srcCoords] : llvm::zip(threadSrcVals, threadSrcCoords)) {
    Value slot = scratchSlot(srcCoords);
    b.store(srcVal, slot);
  }

  // CTA-wide synchronization between the stores above and the loads below.
  b.barrier();

  // Index values held by this thread.
  SmallVector<Value> gatherIdxVals =
      unpackLLElements(loc, adaptor.getIndices(), rewriter);

  // The destination layout gives the coordinates of each output element. For
  // every one of them, the coordinate along the gather axis is replaced by the
  // matching index value:
  //
  //   I = LL(pid)
  //   idx = indices[I]
  //   I_gather = [I[d] if d != axis else idx for d in range(len(I))]
  //   out[I] = src[I_gather]
  RankedTensorType resultTy = op.getType();
  SmallVector<SmallVector<Value>> resultCoords =
      emitIndices(loc, rewriter, targetInfo, resultTy.getEncoding(), resultTy,
                  /*withCTAOffset=*/true);

  const unsigned gatherAxis = op.getAxis();
  SmallVector<Value> gathered(resultCoords.size());
  for (auto [slotNo, rawIdx, coords] :
       llvm::enumerate(gatherIdxVals, resultCoords)) {
    coords[gatherAxis] = convertIndexToI32(loc, rawIdx, rewriter);
    Value elemPtr = scratchSlot(coords);
    gathered[slotNo] = b.load(llvmElemTy, elemPtr);
  }

  rewriter.replaceOp(op, packLLElements(loc, getTypeConverter(), gathered,
                                        rewriter, resultTy));
}
// High-level description of the algorithm:
//
// `isWarpLocal` checks that it is possible to compute each output element
// without data movement across warps.
//
// If the gather dim is `dimN`, then this means
//
// ll^-1(dimN)[(block, warp)] == 0
//
// for both source and index tensors: moving along the gather axis does not
// change the warp. Broadcasted layouts are not supported, so we know the
// layouts are permutation matrices.
//
// We can check this with `ll((block, warp))[dimN] == 0`.
//
// Let `gatherCol` be a tuple of all dimensions except the gather dimension.
// We also check that the gather columns line up the same way with respect to
// the warp between the source and index tensors with
//
// ll_src((block, warp))[gatherCol] == ll_idx((block, warp))[gatherCol]
//
// This means that for all index columns, the corresponding column in the source
// tensor is owned by the same warp.
//
// We also check
//
// ll_src(lane)[gatherCol] == ll_idx(lane)[gatherCol]
//
// This boils down to the fact that the algorithm essentially emits a series of
// index shuffles for each index value owned by each thread, and then a pile of
// selects to pick the right value. We need to figure out given an index value
// in a particular column, what are the source register values it could read
// from and who owns them.
//
// If this relationship did not hold, then the possible source registers for
// each index value would vary with the thread, meaning the value operand
// provided to each shuffle index instruction would depend on the thread ID.
// This isn't a big deal. It just means we would have to emit a pile of selects
// before each shuffle as well, to pick the right source register value. But we
// choose not to handle this.
//
// The codegen algorithm emits code:
// - Given the thread ID and a particular index tensor register, figure out
// which gather column it belongs to using a layout.
// - Using the index value itself as the value for `dimN`, use another layout to
// figure out which lane in the warp owns the desired value and which register
// in that lane it is.
// - For the gather column, figure out the source registers in that column, and
// for each of them, emit an index shuffle with the same computed lane ID.
// - Use the register component to select the right value from the shuffle
// results.
// Lowers `op` with the warp-local algorithm described above: for every index
// value, compute the owning lane and register in the source tensor, shuffle the
// candidate source registers from that lane, and select the matching one. The
// op is replaced by the packed results.
void GatherOpConversion::emitWarpLocalGather(
GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const {
MLIRContext *ctx = op.getContext();
Location loc = op.getLoc();
auto b = TritonLLVMOpBuilder(loc, rewriter);
RankedTensorType srcType = op.getSrc().getType();
RankedTensorType idxType = op.getIndices().getType();
// Layout dimension names.
StringAttr kBlock = str_attr("block");
StringAttr kWarp = str_attr("warp");
StringAttr kLane = str_attr("lane");
StringAttr kRegister = str_attr("register");
StringAttr kGatherDim = rewriter.getStringAttr("dim" + Twine(op.getAxis()));
// `allDims` holds "dim0".."dim<rank-1>"; `otherDims` is the same list without
// the gather dimension.
SmallVector<StringAttr> allDims, otherDims;
for (unsigned dim = 0, rank = srcType.getRank(); dim < rank; ++dim) {
allDims.push_back(str_attr("dim" + Twine(dim)));
if (dim != op.getAxis()) {
otherDims.push_back(allDims.back());
}
}
// Compute the src and idx layouts.
LinearLayout srcLayout = toLinearLayout(srcType);
LinearLayout idxLayout = toLinearLayout(idxType);
// Let `ll_src` be the source layout and `ll_idx` be the index layout.
// Let `src_col` be a tuple of dimensions except the gather dimension,
// representing a specific column in the source tensor. Likewise for
// `idx_col`. Let `src_idx` be the index into gather dimension in the source
// tensor.
//
// `(src_lane, src_reg) = ll_src^-1(src_col, src_idx)`, where `src_lane` is
// the thread that contains the required element and `src_reg` is the register
// within that thread.
//
// Because `ll_src(block=0, warp=0, lane=0)[otherDims] ==
// ll_idx(0, 0, 0)[otherDims]`, we know given any `idx_reg` (element in the
// index tensor) the thread will need to read from the same column in the
// source tensor.
//
// Thus, we can obtain
//
// (src_lane, src_reg) = (ll_src^-1)(
// ll_idx(block, warp, lane, idx_reg)[otherDims],
// idxValues[idx_reg]
// )[{"lane", "register"}]
//
// And the mapping will be correct for each thread.
//
// Given `src_reg \in [0, K*N)`, we just need to emit N index shuffles for
// each `idx_reg` (the number of index shuffles is quadratic!) and
// `llvm.select` using `src_reg` to get the right one. `K` is the number of
// elements per column owned by a thread.
// Invert the source layout. It doesn't matter whether it is fully invertible
// with respect to anything except the register input dimension, since we know
// those don't vary in ways that matter for codegen.
LinearLayout invSrcLayout = srcLayout.pseudoinvert();
// Sanity check: the warp must be invariant to the index because otherwise the
// gather would need to read across warps!
assert(invSrcLayout.sublayoutIsZero(kGatherDim, {kWarp, kBlock}) &&
"expected a warp-local gather");
// Restrict the inverse to its register and lane outputs.
invSrcLayout = invSrcLayout.sublayout(allDims, {kRegister, kLane});
// Map from (block, warp, lane, register) of the index tensor to the
// non-gather coordinates, i.e. the gather column.
LinearLayout idxColLayout =
idxLayout.sublayout({kBlock, kWarp, kLane, kRegister}, otherDims);
SmallVector<Value> srcValues =
unpackLLElements(loc, adaptor.getSrc(), rewriter);
SmallVector<Value> idxValues =
unpackLLElements(loc, adaptor.getIndices(), rewriter);
auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc);
Value blockId = targetInfo.getClusterCTAId(rewriter, loc);
unsigned /*N=*/srcRegsPerThread = srcLayout.getInDimSize(kRegister);
assert(srcRegsPerThread == srcValues.size());
// Given an index value, we need to know which source register values it could
// index into. This is invariant to anything other than the register, which we
// checked already. Compute the full reverse map from
//
// idx_reg -> gather_column -> (src_reg0, src_reg1, ...)
//
LinearLayout invertSrcRegMap = invSrcLayout.sublayout(allDims, {kRegister});
// Remove zero bases in the gather dimension to make the function injective
// (for a given column) over the same codomain.
invertSrcRegMap = invertSrcRegMap.removeZeroBasesAlongDim(kGatherDim);
// We are left with only non-zero bases in the gather dimension, which means
// the number of registers per column is the size of the "gather dimension".
unsigned numRegsPerColumn = invertSrcRegMap.getInDimSize(kGatherDim);
// Get a map from idx_reg to the column it indexes into.
LinearLayout idxRegToCol = idxLayout.sublayout({kRegister}, otherDims);
// Now given `idx_reg`, we can compute the column it belongs to in both src
// and index tensors, then partially apply `invertSrcRegMap` with this to
// obtain a function that outputs the corresponding registers in the src
// tensor in the same column.
// L(column, i) = L(column, 0) xor L(0, i)
LinearLayout invertSrcRegMapColPart =
invertSrcRegMap.sublayout(otherDims, {kRegister});
LinearLayout invertSrcRegMapRest =
invertSrcRegMap.sublayout({kGatherDim}, {kRegister});
SmallVector<Value> results;
for (auto [idxReg, idxVal] : llvm::enumerate(idxValues)) {
// Runtime coordinates of the gather column this index register belongs to.
SmallVector<std::pair<StringAttr, Value>> column =
applyLinearLayout(loc, rewriter, idxColLayout,
{{kRegister, b.i32_val(idxReg)},
{kLane, laneId},
{kWarp, warpId},
{kBlock, blockId}});
assert(column.size() == otherDims.size());
// Combine the computed column with the data-dependent gather index. The
// gather coordinate is inserted at position `axis` of the column list.
column.insert(column.begin() + op.getAxis(),
{kGatherDim, convertIndexToI32(loc, idxVal, rewriter)});
// Runtime (register, lane) in the source tensor that holds the wanted
// element.
SmallVector<std::pair<StringAttr, Value>> srcLaneAndReg =
applyLinearLayout(loc, rewriter, invSrcLayout, column);
// The register component is expected first and the lane component last;
// the assert below checks this.
auto [srcRegName, srcReg] = srcLaneAndReg.front();
auto [srcLaneName, srcLane] = srcLaneAndReg.back();
assert(srcLaneName == kLane && srcRegName == kRegister);
assert(!srcValues.empty() && "can't gather from an empty tensor");
// Figure out which src registers we need to index shuffle from. This is
// invariant to anything else.
SmallVector<std::pair<StringAttr, int32_t>> normalizedColumn =
idxRegToCol.apply({{kRegister, idxReg}});
int32_t srcBase =
invertSrcRegMapColPart.apply(normalizedColumn).front().second;
// Start from undef; each iteration below conditionally overwrites it.
Value result = b.undef(srcValues.front().getType());
// One shuffle and one select per source register of this column.
for (unsigned i = 0; i != numRegsPerColumn; ++i) {
int32_t rest =
invertSrcRegMapRest.apply({{kGatherDim, i}}).front().second;
// Combine the column part and the gather-dimension part with xor, as in
// `L(column, i) = L(column, 0) xor L(0, i)`.
int32_t srcRegIdx = srcBase ^ rest;
Value value =
targetInfo.shuffleIdx(rewriter, loc, srcValues[srcRegIdx], srcLane);
// Keep the shuffled value only when this candidate is the register computed
// at runtime.
result = b.select(b.icmp_eq(b.i32_val(srcRegIdx), srcReg), value, result);
}
results.push_back(result);
}
rewriter.replaceOp(op, packLLElements(loc, getTypeConverter(), results,
rewriter, op.getType()));
}
} // namespace
// Adds the `GatherOp` lowering pattern to `patterns`, constructed with the
// given type converter, target info and benefit.
void triton::populateGatherOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                            RewritePatternSet &patterns,
                                            const TargetInfoBase &targetInfo,
                                            PatternBenefit benefit) {
  patterns.add<GatherOpConversion>(typeConverter, targetInfo, benefit);
}