/*
* Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining
* a copy of this software and associated documentation files
* (the "Software"), to deal in the Software without restriction,
* including without limitation the rights to use, copy, modify, merge,
* publish, distribute, sublicense, and/or sell copies of the Software,
* and to permit persons to whom the Software is furnished to do so,
* subject to the following conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
#include <queue>
#include "mlir/Support/LLVM.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/ErrorHandling.h"
namespace ttg = mlir::triton::gpu;
namespace mlir {
namespace triton {
namespace nvidia_gpu {
#define GEN_PASS_DEF_TRITONGPUPLANCTAPASS
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h.inc"
namespace {
using CastOp = ::mlir::UnrealizedConversionCastOp;
/// Returns how many entries the user range of `value` contains.
unsigned getNumUsers(Value value) {
  auto users = value.getUsers();
  return std::distance(users.begin(), users.end());
}
/// Returns `type` with its tensor encoding replaced by `newLayout`.
///
/// A pointer type is unwrapped first, the pointee is re-encoded if it is a
/// ranked tensor, and the pointer is rebuilt with its original address space.
/// Types that are neither ranked tensors nor pointers come back unchanged.
Type replaceLayout(const Type &type, const Attribute &newLayout) {
  auto ptrTy = dyn_cast<triton::PointerType>(type);
  Type inner = ptrTy ? ptrTy.getPointeeType() : type;
  if (auto tensorTy = dyn_cast<RankedTensorType>(inner))
    inner = tensorTy.cloneWithEncoding(newLayout);
  if (!ptrTy)
    return inner;
  return triton::PointerType::get(inner, ptrTy.getAddressSpace());
}
/// Rebuilds `layout` so that it carries `newCTALayout`.
///
/// Blocked encodings are recreated from their size-per-thread and order;
/// slice encodings recurse into the parent and are re-wrapped with the same
/// dim. Any other encoding is a fatal error.
///
/// NOTE(review): the literal 32 sits in the argument slot that processDot
/// fills with ttg::lookupThreadsPerWarp -- confirm 32 is always intended here.
ttg::DistributedEncodingTrait
replaceCTALayout(ttg::DistributedEncodingTrait layout,
                 llvm::ArrayRef<int64_t> shape, int numWarps,
                 const ttg::CTALayoutAttr &newCTALayout) {
  MLIRContext *ctx = layout.getContext();

  if (auto blocked = mlir::dyn_cast<ttg::BlockedEncodingAttr>(layout))
    return ttg::BlockedEncodingAttr::get(ctx, shape, blocked.getSizePerThread(),
                                         blocked.getOrder(), numWarps, 32,
                                         newCTALayout);

  if (auto slice = mlir::dyn_cast<ttg::SliceEncodingAttr>(layout)) {
    auto newParent =
        replaceCTALayout(slice.getParent(), shape, numWarps, newCTALayout);
    return ttg::SliceEncodingAttr::get(ctx, slice.getDim(), newParent);
  }

  llvm::report_fatal_error("replaceCTALayout not implemented");
  return layout;
}
/// Drives CTA-layout planning for a single function.
///
/// The planner seeds new layouts at anchor ops (dot, reduce, store-like ops),
/// wraps them in UnrealizedConversionCastOps tagged "forward"/"backward", and
/// then drains a work queue of those casts, pushing each one through its
/// neighbouring op until adjacent casts meet and can be eliminated.
class CTAPlanner {
public:
  /// `clusterInfo_` may be null, in which case the planner allocates and owns
  /// a private ClusterInfo.
  explicit CTAPlanner(ClusterInfo *clusterInfo_);
  ~CTAPlanner();

  // The planner may own `clusterInfo` (see `ownInfo`); an implicit copy would
  // lead to a double delete, so copying is disabled.
  CTAPlanner(const CTAPlanner &) = delete;
  CTAPlanner &operator=(const CTAPlanner &) = delete;

  /// Plans `funcOp`. A planner instance is meant to be run once.
  void run(triton::FuncOp &funcOp);

private:
  // --- Direction tags on cast ops ("direction" string attribute). ---
  CastOp markBackward(CastOp cast) const;
  CastOp markForward(CastOp cast) const;
  bool isBackward(CastOp cast) const;
  bool isForward(CastOp cast) const;

  /// Records the cluster dimensions in `clusterInfo` and marks tiling as done.
  void setTiling(llvm::ArrayRef<unsigned> CTAsPerCGA);

  // --- Anchor ops that seed the propagation. ---
  bool processDot(triton::FuncOp &funcOp);
  bool processReduce(triton::FuncOp &funcOp);
  void processStoreLikeOps(triton::FuncOp &funcOp);

  // --- Propagation of one queued cast. ---
  bool propagate(CastOp cast);
  bool propagateBackward(CastOp cast);
  bool propagateForward(CastOp cast);

  // --- Cast bookkeeping. ---
  void eraseCastOp(CastOp cast);
  void eraseCastOpFromQueue(CastOp cast);
  void eraseCastOpsFromQueue(llvm::ArrayRef<CastOp> casts);

  /// Wraps the tensor operands/results of `op` in casts to the given layouts
  /// and queues the new casts.
  void insertCasts(Operation *op, llvm::ArrayRef<Attribute> newOperandLayouts,
                   llvm::ArrayRef<Attribute> newResultLayouts);
  /// Folds a forward cast immediately followed by a backward cast.
  void eliminateAdjacentCasts(CastOp cast0, CastOp cast1);

  // --- Per-op-kind handlers. ---
  bool isLoadStoreOp(Operation *op) const;
  bool processLoadStore(Operation *op, Attribute layout);
  bool isElementwiseOp(Operation *op) const;
  bool processElementwise(Operation *op, Attribute layout);
  bool processConstant(arith::ConstantOp constant, Attribute layout);
  bool processSplat(triton::SplatOp splat, Attribute layout);
  bool processMakeRange(triton::MakeRangeOp makeRange, Attribute layout);
  bool processMakeTensorPtr(triton::MakeTensorPtrOp makeTensorPtr,
                            Attribute layout);
  bool processBroadcast(triton::BroadcastOp broadcast, Attribute layout);
  bool processExpandDimsBackward(triton::ExpandDimsOp expandDims,
                                 ttg::DistributedEncodingTrait newResultLayout);
  bool processExpandDimsForward(triton::ExpandDimsOp expandDims,
                                ttg::DistributedEncodingTrait newSrcLayout);
  bool processConvertLayoutBackward(ttg::ConvertLayoutOp convertLayout,
                                    CastOp cast);
  bool processConvertLayoutForward(ttg::ConvertLayoutOp convertLayout,
                                   CastOp cast);

  // --- Control flow (scf.if / scf.for / scf.yield). ---
  bool processIfOp(scf::IfOp ifOp, int index, const Type &newType);
  bool processForOp(scf::ForOp forOp, int index, const Type &newType);
  bool processIfOpBackward(scf::IfOp ifOp, CastOp cast);
  bool processForOpBackward(scf::ForOp forOp, CastOp cast);
  bool processBlockArgBackward(BlockArgument arg, CastOp cast);
  bool processForOpForward(scf::ForOp forOp, CastOp cast);
  bool processYieldOpForward(scf::YieldOp yieldOp, CastOp cast);

  /// Handler for ops with no dedicated rule.
  bool processOpFallback(Operation *op);

  // --- Values with more than one user. ---
  bool processMultiUsersBackward(Value input, CastOp cast);
  bool processMultiUsersForward(Value output, CastOp cast);

  bool ownInfo;             ///< True when `clusterInfo` was allocated here.
  ClusterInfo *clusterInfo; ///< Receives the chosen cluster dimensions.
  bool tiled;               ///< Set once setTiling() has run.
  unsigned step;            ///< Number of planning steps taken in run().
  unsigned stepUnchanged;   ///< Consecutive propagate() calls without change.
  std::queue<CastOp> queue; ///< Pending casts to propagate.
};
/// Builds a planner that writes its result into `clusterInfo_`. When the
/// caller passes null, a ClusterInfo is allocated here and released by the
/// destructor.
CTAPlanner::CTAPlanner(ClusterInfo *clusterInfo_)
    : ownInfo(clusterInfo_ == nullptr),
      clusterInfo(clusterInfo_ != nullptr ? clusterInfo_ : new ClusterInfo()),
      tiled(false), step(0), stepUnchanged(0) {}
/// Frees the ClusterInfo only if this planner allocated it.
CTAPlanner::~CTAPlanner() {
  if (!ownInfo)
    return;
  delete clusterInfo;
  clusterInfo = nullptr;
  ownInfo = false;
}
/// Runs the planning on `funcOp`.
///
/// Dots and reduces are handled first; store-like ops only seed the tiling if
/// neither of them fixed it. Afterwards the cast queue is drained: a cast that
/// could not make progress is pushed back, and `stepUnchanged` counts how many
/// such attempts happened in a row. The step limit is enforced through an
/// assert only.
void CTAPlanner::run(triton::FuncOp &funcOp) {
  assert(!tiled && "Please create a new CTAPlanner");
  static constexpr unsigned kMaxSteps = 10000;
  auto advance = [&]() {
    ++step;
    assert(step < kMaxSteps && "Maximum number of steps exceeded");
  };

  processDot(funcOp);
  advance();

  processReduce(funcOp);
  advance();

  if (!tiled) {
    processStoreLikeOps(funcOp);
    advance();
  }

  for (; !queue.empty(); advance()) {
    CastOp current = queue.front();
    queue.pop();
    if (propagate(current)) {
      stepUnchanged = 0;
      continue;
    }
    // No progress: retry this cast later.
    queue.push(current);
    ++stepUnchanged;
  }
}
/// Tags `cast` with direction "backward" and returns it.
CastOp CTAPlanner::markBackward(CastOp cast) const {
  auto tag = StringAttr::get(cast.getContext(), "backward");
  cast->setAttr("direction", tag);
  return cast;
}
/// Tags `cast` with direction "forward" and returns it.
CastOp CTAPlanner::markForward(CastOp cast) const {
  auto tag = StringAttr::get(cast.getContext(), "forward");
  cast->setAttr("direction", tag);
  return cast;
}
/// True when the "direction" attribute of `cast` reads "backward".
bool CTAPlanner::isBackward(CastOp cast) const {
  auto direction = cast->getAttrOfType<StringAttr>("direction");
  return direction == "backward";
}
/// True when the "direction" attribute of `cast` reads "forward".
bool CTAPlanner::isForward(CastOp cast) const {
  auto direction = cast->getAttrOfType<StringAttr>("direction");
  return direction == "forward";
}
/// Fixes the cluster tiling and stores it in `clusterInfo`.
///
/// May only be called once per planner. When the total CTA count is exactly
/// two, only clusterDimX is written (set to 2), whichever entry of
/// `CTAsPerCGA` holds the 2. Otherwise the first three entries map to X/Y/Z,
/// and any further entry different from 1 is a fatal error.
void CTAPlanner::setTiling(llvm::ArrayRef<unsigned> CTAsPerCGA) {
  assert(!tiled && "CTA tiling is already determinted");
  assert(clusterInfo && "ClusterInfo pointer is null");
  tiled = true;

  unsigned numCTAs = 1;
  for (unsigned cta : CTAsPerCGA)
    numCTAs *= cta;

  if (numCTAs == 2) {
    clusterInfo->clusterDimX = 2;
    return;
  }

  if (CTAsPerCGA.size() > 0)
    clusterInfo->clusterDimX = CTAsPerCGA[0];
  if (CTAsPerCGA.size() > 1)
    clusterInfo->clusterDimY = CTAsPerCGA[1];
  if (CTAsPerCGA.size() > 2)
    clusterInfo->clusterDimZ = CTAsPerCGA[2];

  // Index with size_t so the comparison against size() is not mixed-sign.
  for (size_t i = 3; i < CTAsPerCGA.size(); ++i)
    if (CTAsPerCGA[i] != 1)
      llvm::report_fatal_error("tiling > 3 dims is not implemented");
}
/// Seeds the planning at every tt.dot in `funcOp`.
///
/// For each dot a (splitM, splitN) CTA split is picked, the tiling is fixed
/// via setTiling(), and casts to the new operand/result layouts are inserted
/// around the dot. Always returns true.
bool CTAPlanner::processDot(triton::FuncOp &funcOp) {
  // Tries an M chunk of 128 and then 64; stops at the first one for which the
  // per-CTA N extent is still at least 64. If none qualifies, the values from
  // the last attempt are returned. K is accepted but not used.
  auto getCTATiling = [](int64_t M, int64_t N, int64_t K,
                         unsigned numCTAs) -> std::pair<unsigned, unsigned> {
    auto isLegal = [](unsigned chunk) { return chunk >= 64; };
    unsigned splitM, splitN;
    unsigned chunk_m = 128;
    while (isLegal(chunk_m)) {
      splitM = std::clamp<unsigned>(M / chunk_m, 1, numCTAs);
      splitN = numCTAs / splitM;
      if (isLegal(N / splitN))
        break;
      chunk_m /= 2;
    }
    return {splitM, splitN};
  };

  funcOp.walk([&](triton::DotOp dot) {
    MLIRContext *ctx = dot.getContext();

    auto aTy = cast<RankedTensorType>(dot.getA().getType());
    auto bTy = cast<RankedTensorType>(dot.getB().getType());
    auto dTy = cast<RankedTensorType>(dot.getD().getType());

    assert(isa<ttg::DotOperandEncodingAttr>(aTy.getEncoding()) &&
           isa<ttg::DotOperandEncodingAttr>(bTy.getEncoding()) &&
           isa<ttg::BlockedEncodingAttr>(dTy.getEncoding()) &&
           "PlanCTAPass should follow immediately after CoalescePass");

    auto aLayout = cast<ttg::DotOperandEncodingAttr>(aTy.getEncoding());
    auto bLayout = cast<ttg::DotOperandEncodingAttr>(bTy.getEncoding());
    auto dLayout = cast<ttg::BlockedEncodingAttr>(dTy.getEncoding());

    auto dShape = dTy.getShape();
    unsigned M = dShape[0];
    unsigned N = dShape[1];
    unsigned K = aTy.getShape()[1];

    unsigned splitM, splitN;
    std::tie(splitM, splitN) = getCTATiling(M, N, K, ttg::getNumCTAs(dLayout));
    setTiling({splitM, splitN, 1});

    OpBuilder builder(dot);
    auto numThreads = ttg::lookupThreadsPerWarp(builder);
    auto numWarps = ttg::lookupNumWarps(dot);

    // The same split is used for CTAsPerCGA and CTASplitNum.
    auto newCTALayout = ttg::CTALayoutAttr::get(ctx, {splitM, splitN},
                                                {splitM, splitN}, {1, 0});
    auto newDLayout = ttg::BlockedEncodingAttr::get(
        ctx, dShape, dLayout.getSizePerThread(), dLayout.getOrder(), numWarps,
        numThreads, newCTALayout);
    auto newALayout = ttg::DotOperandEncodingAttr::get(ctx, aLayout.getOpIdx(),
                                                       newDLayout, 0);
    auto newBLayout = ttg::DotOperandEncodingAttr::get(ctx, bLayout.getOpIdx(),
                                                       newDLayout, 0);

    insertCasts(dot.getOperation(), {newALayout, newBLayout, newDLayout},
                {newDLayout});
  });

  return true;
}
// Seeds the planning at every tt.reduce in `funcOp`: computes a CTA layout
// that keeps the reduced axis inside a single CTA, fixes the tiling if it has
// not been fixed yet, and inserts casts around the reduce. Always returns
// true.
bool CTAPlanner::processReduce(triton::FuncOp &funcOp) {
ModuleOp mod = funcOp->getParentOfType<ModuleOp>();
unsigned numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod);
funcOp.walk([&](triton::ReduceOp reduce) {
MLIRContext *context = reduce.getContext();
// Only the first operand is inspected; its layout is applied to all operands
// below.
Value src = reduce.getOperands()[0];
unsigned axis = reduce.getAxis();
auto srcTy = cast<RankedTensorType>(src.getType());
auto srcShape = srcTy.getShape();
auto srcLayout = srcTy.getEncoding();
auto rank = srcShape.size();
auto order = ttg::getOrder(srcTy);
auto sizePerThread = ttg::getContigPerThread(srcTy);
auto CTAOrder = ttg::getCTAOrder(srcLayout);
llvm::SmallVector<unsigned> CTAsPerCGA(rank, 0);
unsigned remainingCTAs = numCTAs;
// Walk `order` from its last entry to its first. The reduced axis gets one
// CTA; every other dim takes as many of the remaining CTAs as
// shape / sizePerThread allows.
// NOTE(review): if srcShape[dim] < sizePerThread[dim] the min() yields 0 and
// the following division divides by zero -- confirm this cannot happen.
for (int i = rank - 1; i >= 0; --i) {
unsigned dim = order[i];
if (dim == axis) {
CTAsPerCGA[dim] = 1;
} else {
CTAsPerCGA[dim] = std::min<unsigned>(srcShape[dim] / sizePerThread[dim],
remainingCTAs);
remainingCTAs /= CTAsPerCGA[dim];
}
}
// Hand whatever is left to the first non-reduced dim met in the same walk.
for (int i = rank - 1; i >= 0; --i) {
unsigned dim = order[i];
if (dim != axis) {
CTAsPerCGA[dim] *= remainingCTAs;
break;
}
}
// CTASplitNum is a snapshot taken before the final multiply below, so the two
// vectors can differ in the order[rank - 1] entry.
llvm::SmallVector<unsigned> CTASplitNum = CTAsPerCGA;
// NOTE(review): remainingCTAs is not reset by the loop above, so when a
// non-reduced dim exists the leftover may be multiplied in a second time here
// -- confirm this is intended.
if (remainingCTAs > 0)
CTAsPerCGA[order[rank - 1]] *= remainingCTAs;
auto numWarps = ttg::lookupNumWarps(reduce);
auto CTALayout =
ttg::CTALayoutAttr::get(context, CTAsPerCGA, CTASplitNum, CTAOrder);
// A dot processed earlier may already have fixed the tiling.
if (!tiled)
setTiling(CTALayout.getCTAsPerCGA());
auto newSrcLayout =
replaceCTALayout(cast<ttg::DistributedEncodingTrait>(srcLayout),
srcShape, numWarps, CTALayout);
// The results use the slice of the new source layout along the reduced axis.
auto newResultLayout =
ttg::SliceEncodingAttr::get(context, axis, newSrcLayout);
unsigned numOperands = reduce.getNumOperands();
SmallVector<Attribute> newSrcLayoutVec(numOperands, newSrcLayout);
SmallVector<Attribute> newResultLayoutVec(numOperands, newResultLayout);
insertCasts(reduce.getOperation(), newSrcLayoutVec, newResultLayoutVec);
});
return true;
}
/// Seeds the planning from store-like ops (stores, atomics, descriptor
/// stores); run() calls this only when no earlier step fixed the tiling.
///
/// The first tensor-typed stored value determines the CTA layout and the
/// tiling; every tensor-typed store is then wrapped in casts to the re-based
/// layout. If no store carries a tensor, the tiling defaults to {1, 1, 1}.
void CTAPlanner::processStoreLikeOps(triton::FuncOp &funcOp) {
  assert(!tiled && "CTA tiling is already determinted");

  llvm::SmallVector<Operation *> stores;
  funcOp.walk([&](Operation *op) {
    const bool storeLike =
        llvm::isa<triton::StoreOp, triton::AtomicRMWOp, triton::AtomicCASOp,
                  triton::DescriptorStoreLikeOpInterface>(op);
    if (storeLike)
      stores.push_back(op);
  });
  assert(stores.size() > 0 && "Cannot find store-like ops");

  auto numWarps = ttg::lookupNumWarps(funcOp);
  ttg::CTALayoutAttr CTALayout;

  for (Operation *store : stores) {
    // Descriptor stores expose the stored value via getSrc(); for the other
    // ops operand 0 is used.
    Value val;
    if (auto descStore =
            dyn_cast<triton::DescriptorStoreLikeOpInterface>(store))
      val = descStore.getSrc();
    else
      val = store->getOperand(0);

    auto tensorTy = dyn_cast<RankedTensorType>(val.getType());
    if (!tensorTy)
      continue;

    if (!tiled) {
      CTALayout = ttg::getCTALayout(tensorTy.getEncoding());
      setTiling(CTALayout.getCTAsPerCGA());
    }

    auto oldLayout =
        cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding());
    auto newLayout =
        replaceCTALayout(oldLayout, tensorTy.getShape(), numWarps, CTALayout);
    processElementwise(store, newLayout);
  }

  if (!tiled)
    setTiling({1, 1, 1});
}
/// Dispatches `cast` by its direction tag; anything not tagged "backward" is
/// treated as forward.
bool CTAPlanner::propagate(CastOp cast) {
  if (isBackward(cast))
    return propagateBackward(cast);
  return propagateForward(cast);
}
/// Pushes a backward cast through the producer of its operand.
///
/// With several users of the operand the decision is delegated to
/// processMultiUsersBackward(). With a single user, the producer (or block
/// argument) is rewritten to yield the layout the cast asks for. Returns
/// whether the IR changed; only the convert-layout, fallback and multi-user
/// paths can report something other than true.
bool CTAPlanner::propagateBackward(CastOp cast) {
  Value input = cast.getOperand(0);
  Value output = cast.getResult(0);

  const unsigned numUsers = getNumUsers(input);
  if (numUsers == 0) {
    llvm::report_fatal_error("Unreachable branch");
    return false;
  }
  if (numUsers > 1)
    return processMultiUsersBackward(input, cast);

  // The requested layout is the encoding of the cast result (looking through
  // a tensor pointer if needed).
  Type outTy = output.getType();
  if (auto ptrTy = dyn_cast<triton::PointerType>(outTy))
    outTy = ptrTy.getPointeeType();
  auto layout = mlir::cast<ttg::DistributedEncodingTrait>(
      mlir::cast<RankedTensorType>(outTy).getEncoding());

  Operation *op = input.getDefiningOp();
  if (op == nullptr) {
    assert(isa<BlockArgument>(input) &&
           "Unexpected Value without defining op");
    processBlockArgBackward(llvm::cast<BlockArgument>(input), cast);
    return true;
  }
  if (auto prevCast = llvm::dyn_cast<CastOp>(op)) {
    eliminateAdjacentCasts(prevCast, cast);
    return true;
  }
  if (isLoadStoreOp(op)) {
    processLoadStore(op, layout);
    return true;
  }
  // tt.advance is handled exactly like an elementwise op.
  if (isElementwiseOp(op) || llvm::isa<triton::AdvanceOp>(op)) {
    processElementwise(op, layout);
    return true;
  }
  if (auto constant = llvm::dyn_cast<arith::ConstantOp>(op)) {
    processConstant(constant, layout);
    return true;
  }
  if (auto splat = llvm::dyn_cast<triton::SplatOp>(op)) {
    processSplat(splat, layout);
    return true;
  }
  if (auto makeRange = llvm::dyn_cast<triton::MakeRangeOp>(op)) {
    processMakeRange(makeRange, layout);
    return true;
  }
  if (auto makeTensorPtr = llvm::dyn_cast<triton::MakeTensorPtrOp>(op)) {
    processMakeTensorPtr(makeTensorPtr, layout);
    return true;
  }
  if (auto broadcast = llvm::dyn_cast<triton::BroadcastOp>(op)) {
    processBroadcast(broadcast, layout);
    return true;
  }
  if (auto expandDims = llvm::dyn_cast<triton::ExpandDimsOp>(op)) {
    processExpandDimsBackward(expandDims, layout);
    return true;
  }
  if (auto ifOp = llvm::dyn_cast<scf::IfOp>(op)) {
    processIfOpBackward(ifOp, cast);
    return true;
  }
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(op)) {
    processForOpBackward(forOp, cast);
    return true;
  }
  if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(op))
    return processConvertLayoutBackward(convertLayout, cast);

  return processOpFallback(op);
}
/// Pushes a forward cast through the consumer of its result.
///
/// A cast without users is simply erased; a cast with several users is split
/// by processMultiUsersForward(). With a single user, that user is rewritten
/// to accept the layout of the cast operand. Only the convert-layout and
/// fallback paths can return something other than true.
bool CTAPlanner::propagateForward(CastOp cast) {
  Value input = cast.getOperand(0);
  Value output = cast.getResult(0);

  const unsigned numUsers = getNumUsers(output);
  if (numUsers == 0) {
    cast.erase();
    return true;
  }
  if (numUsers > 1) {
    processMultiUsersForward(output, cast);
    return true;
  }

  // The layout to propagate is the encoding of the cast operand (looking
  // through a tensor pointer if needed).
  Type inTy = input.getType();
  if (auto ptrTy = dyn_cast<triton::PointerType>(inTy))
    inTy = ptrTy.getPointeeType();
  Attribute layout = mlir::cast<RankedTensorType>(inTy).getEncoding();

  Operation *op = *output.user_begin();
  if (auto nextCast = llvm::dyn_cast<CastOp>(op)) {
    eliminateAdjacentCasts(cast, nextCast);
    return true;
  }
  if (isLoadStoreOp(op)) {
    processLoadStore(op, layout);
    return true;
  }
  // tt.advance is handled exactly like an elementwise op.
  if (isElementwiseOp(op) || llvm::isa<triton::AdvanceOp>(op)) {
    processElementwise(op, layout);
    return true;
  }
  if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(op))
    return processConvertLayoutForward(convertLayout, cast);
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(op)) {
    processForOpForward(forOp, cast);
    return true;
  }
  if (auto yieldOp = llvm::dyn_cast<scf::YieldOp>(op)) {
    processYieldOpForward(yieldOp, cast);
    return true;
  }

  return processOpFallback(op);
}
/// Erases `cast` from the IR; the cast result must have no users left.
void CTAPlanner::eraseCastOp(CastOp cast) {
  assert(getNumUsers(cast.getResult(0)) == 0 &&
         "Cannot erase CastOp because it is still in use");
  cast.erase();
}
/// Single-cast convenience wrapper around eraseCastOpsFromQueue().
void CTAPlanner::eraseCastOpFromQueue(CastOp cast) {
  llvm::ArrayRef<CastOp> single(cast);
  eraseCastOpsFromQueue(single);
}
/// Erases every cast in `casts` from the IR and drops them from the work
/// queue, keeping the relative order of the remaining entries.
void CTAPlanner::eraseCastOpsFromQueue(llvm::ArrayRef<CastOp> casts) {
  // Remember the handles first, then destroy the ops.
  llvm::DenseSet<CastOp> erased;
  erased.insert(casts.begin(), casts.end());
  for (CastOp doomed : casts)
    eraseCastOp(doomed);

  // std::queue cannot remove from the middle, so rebuild it.
  decltype(queue) pending;
  std::swap(queue, pending);
  for (; !pending.empty(); pending.pop()) {
    CastOp candidate = pending.front();
    if (erased.contains(candidate))
      continue;
    queue.push(candidate);
  }
}
/// Switches `op` to new operand/result layouts and isolates the change with
/// cast ops.
///
/// Each tensor (or tensor-pointer) operand is routed through a backward cast
/// to the new operand type. Each such result has its type changed in place,
/// while a forward cast placed after `op` restores the old type for the
/// existing users. All new casts are queued. Entries of the layout arrays that
/// belong to non-tensor values are not read.
void CTAPlanner::insertCasts(Operation *op,
                             llvm::ArrayRef<Attribute> newOperandLayouts,
                             llvm::ArrayRef<Attribute> newResultLayouts) {
  assert(op->getNumOperands() == newOperandLayouts.size() &&
         "NumOperands mismatched");
  assert(op->getNumResults() == newResultLayouts.size() &&
         "NumResults mismatched");

  Location loc = op->getLoc();
  OpBuilder builder(op->getContext());

  // Operand side: casts go right before `op`.
  builder.setInsertionPoint(op);
  for (OpOperand &use : op->getOpOperands()) {
    Value operand = use.get();
    if (!triton::isTensorOrTensorPointerType(operand.getType()))
      continue;
    Type targetTy = replaceLayout(operand.getType(),
                                  newOperandLayouts[use.getOperandNumber()]);
    auto backCast =
        markBackward(builder.create<CastOp>(loc, targetTy, operand));
    use.set(backCast.getResult(0));
    queue.push(backCast);
  }

  // Result side: casts go right after `op`.
  builder.setInsertionPointAfter(op);
  for (OpResult result : op->getResults()) {
    Type oldTy = result.getType();
    if (!triton::isTensorOrTensorPointerType(oldTy))
      continue;
    Type targetTy =
        replaceLayout(oldTy, newResultLayouts[result.getResultNumber()]);
    // The cast is created while the result still has its old type, so the
    // cast output keeps that type for the existing users.
    auto fwdCast = markForward(builder.create<CastOp>(loc, oldTy, result));
    result.setType(targetTy);
    result.replaceAllUsesExcept(fwdCast.getResult(0), fwdCast.getOperation());
    queue.push(fwdCast);
  }
}
/// Removes a forward cast (`cast0`) directly feeding a backward cast
/// (`cast1`).
///
/// If the types on both ends agree the pair just disappears; otherwise a
/// ttg.convert_layout is placed before `cast1` to bridge the two types. Both
/// casts are erased and removed from the queue.
void CTAPlanner::eliminateAdjacentCasts(CastOp cast0, CastOp cast1) {
  assert(cast0.getResult(0) == cast1.getOperand(0) &&
         "The two casts are not adjacent");
  assert(isForward(cast0) && isBackward(cast1) &&
         "Expected pattern of adjacent casts: forward + backward");

  Value input = cast0.getOperand(0);
  Value output = cast1.getResult(0);

  Value replacement = input;
  if (input.getType() != output.getType()) {
    OpBuilder builder(cast1.getOperation());
    auto cvt = builder.create<ttg::ConvertLayoutOp>(cast1.getLoc(),
                                                    output.getType(), input);
    replacement = cvt.getResult();
  }

  output.replaceAllUsesWith(replacement);
  eraseCastOpsFromQueue({cast1, cast0});
}
/// True for the memory ops handled by processLoadStore(): loads, stores,
/// atomics and the descriptor load/store/gather ops.
bool CTAPlanner::isLoadStoreOp(Operation *op) const {
  const bool matches =
      llvm::isa<triton::LoadOp, triton::StoreOp, triton::AtomicRMWOp,
                triton::AtomicCASOp, triton::DescriptorLoadOp,
                triton::DescriptorStoreLikeOpInterface,
                triton::DescriptorGatherOp>(op);
  return matches;
}
/// Applies the CTA layout of `layout` to every tensor operand and result of a
/// load/store-like op, keeping the rest of each value's own layout.
///
/// If `layout` is a slice whose parent uses more than one CTA along the
/// sliced dim, the request is replaced by the op's current layout (taken from
/// result 0, or operand 0 when there are no results). Always returns true.
bool CTAPlanner::processLoadStore(Operation *op, Attribute layout) {
  if (auto sliceLayout = mlir::dyn_cast<ttg::SliceEncodingAttr>(layout)) {
    auto slicedDim = sliceLayout.getDim();
    auto parentCTAs = ttg::getCTAsPerCGA(sliceLayout.getParent());
    if (parentCTAs[slicedDim] > 1) {
      Value anchor =
          op->getNumResults() > 0 ? op->getResult(0) : op->getOperand(0);
      Attribute originalLayout =
          cast<RankedTensorType>(anchor.getType()).getEncoding();
      return processLoadStore(op, originalLayout);
    }
  }

  auto CTALayout = ttg::getCTALayout(layout);
  auto numWarps = ttg::lookupNumWarps(op);

  // Looks through a tensor-pointer type to its pointee.
  auto stripPointer = [](Type type) -> Type {
    if (auto ptrTy = dyn_cast<triton::PointerType>(type))
      return ptrTy.getPointeeType();
    return type;
  };
  // Re-bases the layout of `tensorTy` onto the requested CTA layout.
  auto rebase = [&](RankedTensorType tensorTy) -> Attribute {
    auto oldLayout =
        cast<ttg::DistributedEncodingTrait>(tensorTy.getEncoding());
    return replaceCTALayout(oldLayout, tensorTy.getShape(), numWarps,
                            CTALayout);
  };

  llvm::SmallVector<Attribute> newOperandLayouts;
  for (Value operand : op->getOperands()) {
    auto tensorTy = dyn_cast<RankedTensorType>(stripPointer(operand.getType()));
    // Non-tensor operands get a null placeholder.
    newOperandLayouts.push_back(tensorTy ? rebase(tensorTy) : Attribute());
  }

  llvm::SmallVector<Attribute> newResultLayouts;
  for (Value result : op->getResults()) {
    // Results are required to be tensors (or pointers to tensors).
    auto tensorTy = cast<RankedTensorType>(stripPointer(result.getType()));
    newResultLayouts.push_back(rebase(tensorTy));
  }

  insertCasts(op, newOperandLayouts, newResultLayouts);
  return true;
}
/// True for ops that processElementwise() may treat as elementwise: a fixed
/// list of arith, math and triton ops, the compare/select ops, and
/// tt.extern_elementwise when it is marked pure.
bool CTAPlanner::isElementwiseOp(Operation *op) const {
  // An extern call qualifies only if it is pure.
  if (auto externElementwiseOp = dyn_cast<triton::ExternElementwiseOp>(op))
    return externElementwiseOp.getPure();

  const bool isArith =
      llvm::isa<arith::AddFOp, arith::AddIOp, arith::AndIOp, arith::CeilDivSIOp,
                arith::CeilDivUIOp, arith::DivFOp, arith::DivSIOp,
                arith::DivUIOp, arith::ExtFOp, arith::ExtSIOp, arith::ExtUIOp,
                arith::FloorDivSIOp, arith::FPToSIOp, arith::FPToUIOp,
                arith::MaximumFOp, arith::MaxNumFOp, arith::MaxSIOp,
                arith::MaxUIOp, arith::MinimumFOp, arith::MinNumFOp,
                arith::MinSIOp, arith::MinUIOp, arith::MulFOp, arith::MulIOp,
                arith::MulUIExtendedOp, arith::MulSIExtendedOp, arith::NegFOp,
                arith::OrIOp, arith::RemFOp, arith::RemSIOp, arith::RemUIOp,
                arith::ShLIOp, arith::ShRSIOp, arith::ShRUIOp, arith::SIToFPOp,
                arith::SubFOp, arith::SubIOp, arith::TruncFOp, arith::TruncIOp,
                arith::UIToFPOp, arith::XOrIOp>(op);
  if (isArith)
    return true;

  const bool isMath =
      llvm::isa<math::AbsFOp, math::AbsIOp, math::AtanOp, math::Atan2Op,
                math::CeilOp, math::CopySignOp, math::CosOp, math::SinOp,
                math::CountLeadingZerosOp, math::CountTrailingZerosOp,
                math::CtPopOp, math::ErfOp, math::ExpOp, math::Exp2Op,
                math::FloorOp, math::ExpM1Op, math::FmaOp, math::LogOp,
                math::Log10Op, math::Log1pOp, math::Log2Op, math::PowFOp,
                math::SqrtOp, math::RsqrtOp, math::TanhOp>(op);
  if (isMath)
    return true;

  const bool isTriton =
      llvm::isa<triton::IntToPtrOp, triton::PtrToIntOp, triton::BitcastOp,
                triton::FpToFpOp, triton::AddPtrOp, triton::PreciseSqrtOp,
                triton::PreciseDivFOp>(op);
  if (isTriton)
    return true;

  return llvm::isa<arith::CmpIOp, arith::CmpFOp, arith::SelectOp>(op);
}
/// Requests the same `layout` for every operand and every result of `op`.
/// Always returns true.
bool CTAPlanner::processElementwise(Operation *op, Attribute layout) {
  const unsigned numOperands = op->getNumOperands();
  const unsigned numResults = op->getNumResults();
  llvm::SmallVector<Attribute> operandLayouts(numOperands, layout);
  llvm::SmallVector<Attribute> resultLayouts(numResults, layout);
  insertCasts(op, operandLayouts, resultLayouts);
  return true;
}
/// Gives an arith.constant the requested `layout`.
///
/// For a tensor constant holding a splat, the value attribute is rebuilt with
/// the re-encoded tensor type. The result is then re-typed through
/// insertCasts(). Always returns true.
bool CTAPlanner::processConstant(arith::ConstantOp constant, Attribute layout) {
  auto tensorTy = dyn_cast<RankedTensorType>(constant.getType());
  if (tensorTy) {
    auto splatAttr = dyn_cast<SplatElementsAttr>(constant.getValue());
    if (splatAttr) {
      auto retypedTy = tensorTy.cloneWithEncoding(layout);
      constant.setValueAttr(SplatElementsAttr::get(
          retypedTy, splatAttr.getSplatValue<Attribute>()));
    }
  }
  // A constant has no operands, hence the empty operand-layout list.
  insertCasts(constant.getOperation(), {}, {layout});
  return true;
}
/// Gives the result of tt.splat the requested `layout`; the operand slot gets
/// a null placeholder layout. Always returns true.
bool CTAPlanner::processSplat(triton::SplatOp splat, Attribute layout) {
  Operation *op = splat.getOperation();
  insertCasts(op, {Attribute()}, {layout});
  return true;
}
/// Gives the result of tt.make_range the requested `layout`; no operand
/// layouts are passed. Always returns true.
bool CTAPlanner::processMakeRange(triton::MakeRangeOp makeRange,
                                  Attribute layout) {
  Operation *op = makeRange.getOperation();
  insertCasts(op, {}, {layout});
  return true;
}
/// Gives the result of tt.make_tensor_ptr the requested `layout`; every
/// operand slot gets a null placeholder layout. Always returns true.
bool CTAPlanner::processMakeTensorPtr(triton::MakeTensorPtrOp makeTensorPtr,
                                      Attribute layout) {
  llvm::SmallVector<Attribute> placeholderLayouts(
      makeTensorPtr.getNumOperands(), Attribute());
  insertCasts(makeTensorPtr.getOperation(), placeholderLayouts, {layout});
  return true;
}
/// Requests `layout` for both the source and the result of tt.broadcast.
/// Always returns true.
bool CTAPlanner::processBroadcast(triton::BroadcastOp broadcast,
                                  Attribute layout) {
  Operation *op = broadcast.getOperation();
  insertCasts(op, {layout}, {layout});
  return true;
}
/// Propagates a result layout backward through tt.expand_dims: the source is
/// asked for the slice of `newResultLayout` along the expanded axis. Always
/// returns true.
bool CTAPlanner::processExpandDimsBackward(
    triton::ExpandDimsOp expandDims,
    ttg::DistributedEncodingTrait newResultLayout) {
  MLIRContext *ctx = newResultLayout.getContext();
  auto slicedSrcLayout =
      ttg::SliceEncodingAttr::get(ctx, expandDims.getAxis(), newResultLayout);
  insertCasts(expandDims.getOperation(), {slicedSrcLayout}, {newResultLayout});
  return true;
}
/// Forward propagation through tt.expand_dims is not supported; calling this
/// is a fatal error. Both parameters are unused.
bool CTAPlanner::processExpandDimsForward(
    triton::ExpandDimsOp expandDims,
    ttg::DistributedEncodingTrait newSrcLayout) {
  (void)expandDims;
  (void)newSrcLayout;
  llvm::report_fatal_error("processExpandDimsForward not implemented yet");
  return true;
}
/// Removes a ttg.convert_layout that feeds a backward cast: its users are
/// rewired to the conversion source, the op is erased, and `cast` is queued
/// again. Always returns true.
bool CTAPlanner::processConvertLayoutBackward(
    ttg::ConvertLayoutOp convertLayout, CastOp cast) {
  Value converted = convertLayout.getResult();
  assert(getNumUsers(converted) == 1 &&
         "Expect to call processMultiUsersBackward first");
  converted.replaceAllUsesWith(convertLayout.getSrc());
  convertLayout.erase();
  queue.push(cast);
  return true;
}
/// Removes a ttg.convert_layout fed by a forward cast: the conversion source
/// takes over the result type, the users are rewired to it, the op is erased,
/// and `cast` is queued again. Always returns true.
bool CTAPlanner::processConvertLayoutForward(ttg::ConvertLayoutOp convertLayout,
                                             CastOp cast) {
  Value source = convertLayout.getSrc();
  Value converted = convertLayout.getResult();
  assert(getNumUsers(source) == 1 &&
         "Expect to call processMultiUsersForward first");
  source.setType(converted.getType());
  converted.replaceAllUsesWith(source);
  convertLayout.erase();
  queue.push(cast);
  return true;
}
/// Changes result `index` of an scf.if to `newType`.
///
/// A forward cast after the if keeps the old type for the existing users, and
/// a backward cast is put in front of the matching operand of the then- and
/// else-yield. All new casts are queued. Always returns true.
bool CTAPlanner::processIfOp(scf::IfOp ifOp, int index, const Type &newType) {
  assert(index < ifOp.getNumResults() && "Invalid result index of IfOp");
  assert(index < ifOp.thenYield().getNumOperands() &&
         "Invalid operand index of YieldOp");
  assert(index < ifOp.elseYield().getNumOperands() &&
         "Invalid operand index of YieldOp");

  Location loc = ifOp.getLoc();
  OpBuilder builder(ifOp.getContext());

  // Result side.
  Value result = ifOp.getResult(index);
  Type oldResultTy = result.getType();
  builder.setInsertionPointAfter(ifOp.getOperation());
  CastOp resultCast =
      markForward(builder.create<CastOp>(loc, oldResultTy, result));
  result.setType(newType);
  result.replaceAllUsesExcept(resultCast.getResult(0),
                              resultCast.getOperation());
  queue.push(resultCast);

  // Yield side, then-branch first.
  for (scf::YieldOp yield : {ifOp.thenYield(), ifOp.elseYield()}) {
    builder.setInsertionPoint(yield.getOperation());
    CastOp yieldCast = markBackward(
        builder.create<CastOp>(loc, newType, yield.getOperand(index)));
    yield->setOperand(index, yieldCast.getResult(0));
    queue.push(yieldCast);
  }
  return true;
}
/// Changes loop-carried value `index` of an scf.for to `newType`.
///
/// `index` counts loop-carried values, i.e. it excludes the control operands
/// and the induction variable. Four places are updated, each isolated by a
/// queued cast:
///   1. the init operand  (backward cast before the loop),
///   2. the region arg    (forward cast at the start of the body),
///   3. the yield operand (backward cast before the yield),
///   4. the loop result   (forward cast after the loop).
/// Always returns true.
bool CTAPlanner::processForOp(scf::ForOp forOp, int index,
                              const Type &newType) {
  Block *body = forOp.getBody();
  auto yield = llvm::cast<scf::YieldOp>(forOp.getBody()->getTerminator());

  // processBlockArgBackward derives the index by subtraction, so guard
  // against a negative value as well as an out-of-range one.
  assert(index >= 0 && "Negative loop-carried index of ForOp");
  assert(index + forOp.getNumControlOperands() < forOp.getNumOperands() &&
         "Invalid operand index of ForOp");
  assert(index + forOp.getNumInductionVars() < body->getNumArguments() &&
         "Invalid block arg index of ForOp");
  assert(index < yield.getNumOperands() && "Invalid operand index of YieldOp");
  assert(index < forOp.getNumResults() && "Invalid result index of ForOp");

  Location loc = forOp.getLoc();
  OpBuilder builder(forOp.getContext());

  // 1. Init operand.
  OpOperand &operand =
      forOp->getOpOperand(index + forOp.getNumControlOperands());
  builder.setInsertionPoint(forOp.getOperation());
  auto newCast =
      markBackward(builder.create<CastOp>(loc, newType, operand.get()));
  operand.set(newCast.getResult(0));
  queue.push(newCast);

  // 2. Region argument: the cast keeps the old type for the body's users.
  Value arg = body->getArgument(index + forOp.getNumInductionVars());
  builder.setInsertionPointToStart(body);
  newCast = markForward(builder.create<CastOp>(loc, arg.getType(), arg));
  arg.setType(newType);
  arg.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation());
  queue.push(newCast);

  // 3. Yield operand.
  Value yieldSrc = yield.getOperand(index);
  builder.setInsertionPoint(yield.getOperation());
  newCast = markBackward(builder.create<CastOp>(loc, newType, yieldSrc));
  yield->setOperand(index, newCast.getResult(0));
  queue.push(newCast);

  // 4. Loop result: the cast keeps the old type for users after the loop.
  Value result = forOp.getResult(index);
  builder.setInsertionPointAfter(forOp.getOperation());
  newCast = markForward(builder.create<CastOp>(loc, result.getType(), result));
  result.setType(newType);
  result.replaceAllUsesExcept(newCast.getResult(0), newCast.getOperation());
  queue.push(newCast);

  return true;
}
/// Returns the position of `result` among the results of `op`; it is a fatal
/// error if `result` is not one of them.
int findResultIndex(Operation *op, Value result) {
  for (OpResult candidate : op->getResults())
    if (candidate == result)
      return candidate.getResultNumber();
  llvm::report_fatal_error("Invalid index of op result");
  return -1;
}
/// Backward step into an scf.if: the if result feeding `cast` takes on the
/// type that `cast` produces.
bool CTAPlanner::processIfOpBackward(scf::IfOp ifOp, CastOp cast) {
  Value producedByIf = cast.getOperand(0);
  int index = findResultIndex(ifOp.getOperation(), producedByIf);
  Type requestedTy = cast.getResult(0).getType();
  return processIfOp(ifOp, index, requestedTy);
}
/// Backward step into an scf.for: the loop result feeding `cast` takes on the
/// type that `cast` produces.
bool CTAPlanner::processForOpBackward(scf::ForOp forOp, CastOp cast) {
  Value producedByLoop = cast.getOperand(0);
  int index = findResultIndex(forOp.getOperation(), producedByLoop);
  Type requestedTy = cast.getResult(0).getType();
  return processForOp(forOp, index, requestedTy);
}
/// Backward step onto a block argument. Only arguments of an scf.for body are
/// supported; any other parent op is a fatal error.
bool CTAPlanner::processBlockArgBackward(BlockArgument arg, CastOp cast) {
  auto forOp = llvm::dyn_cast<scf::ForOp>(arg.getOwner()->getParentOp());
  if (!forOp) {
    llvm::report_fatal_error("Unexpected parent op of block argument");
    return true;
  }
  // Translate the block-argument number into a loop-carried index by skipping
  // the induction variable(s).
  int index = int(arg.getArgNumber()) - forOp.getNumInductionVars();
  Type requestedTy = cast.getResult(0).getType();
  return processForOp(forOp, index, requestedTy);
}
/// Forward step into an scf.for whose init operand is the result of `cast`:
/// the matching loop-carried value takes on the type of the cast operand.
bool CTAPlanner::processForOpForward(scf::ForOp forOp, CastOp cast) {
  // Operand number of the (first) use, shifted past the control operands.
  int index = cast.getResult(0).use_begin()->getOperandNumber() -
              forOp.getNumControlOperands();
  Type sourceTy = cast.getOperand(0).getType();
  return processForOp(forOp, index, sourceTy);
}
/// Forward step into an scf.yield: the yielded slot fed by `cast` takes on
/// the type of the cast operand, applied to the enclosing scf.if or scf.for.
/// Any other parent op is a fatal error.
bool CTAPlanner::processYieldOpForward(scf::YieldOp yieldOp, CastOp cast) {
  int index = cast.getResult(0).use_begin()->getOperandNumber();
  Type sourceTy = cast.getOperand(0).getType();

  Operation *parent = yieldOp->getParentOp();
  if (auto ifOp = llvm::dyn_cast<scf::IfOp>(parent))
    return processIfOp(ifOp, index, sourceTy);
  if (auto forOp = llvm::dyn_cast<scf::ForOp>(parent))
    return processForOp(forOp, index, sourceTy);

  llvm::report_fatal_error("Unexpected parent op of YieldOp");
  return true;
}
/// Handler for ops without a dedicated rule: `op` keeps all its types, and
/// every tensor (or tensor-pointer) operand and result is fenced with a
/// same-type cast (backward before operands, forward after results) that is
/// queued. Always returns true.
bool CTAPlanner::processOpFallback(Operation *op) {
  Location loc = op->getLoc();
  OpBuilder builder(op->getContext());

  builder.setInsertionPoint(op);
  for (OpOperand &use : op->getOpOperands()) {
    Value operand = use.get();
    Type operandTy = operand.getType();
    if (!triton::isTensorOrTensorPointerType(operandTy))
      continue;
    auto fence = markBackward(builder.create<CastOp>(loc, operandTy, operand));
    use.set(fence.getResult(0));
    queue.push(fence);
  }

  builder.setInsertionPointAfter(op);
  for (OpResult result : op->getResults()) {
    Type resultTy = result.getType();
    if (!triton::isTensorOrTensorPointerType(resultTy))
      continue;
    auto fence = markForward(builder.create<CastOp>(loc, resultTy, result));
    result.replaceAllUsesExcept(fence.getResult(0), fence.getOperation());
    queue.push(fence);
  }
  return true;
}
/// Backward step for a value (`input`) that has more than one user.
///
/// Phase 1 looks at the uses of `input`. A use that is a cast is recorded
/// under the type it produces. A use that is not a cast makes the function
/// give up (return false) while `stepUnchanged <= queue.size()`; past that
/// point the use is wrapped in a backward + forward cast pair so it can be
/// recorded as well.
///
/// Phase 2 emits one backward cast per distinct requested type. The first
/// type reuses `input`; each further type clones the defining op of `input`
/// (a block argument cannot be cloned, which is a fatal error). The recorded
/// casts of that type are replaced by the new cast and erased.
///
/// Returns true if the IR was rewritten.
bool CTAPlanner::processMultiUsersBackward(Value input, CastOp cast) {
  Location loc = input.getLoc();
  OpBuilder builder(input.getContext());

  // Requested type -> sibling casts asking for it.
  llvm::DenseMap<Type, llvm::SmallVector<CastOp>> typeToIndices;
  // NOTE(review): operand.set() below moves the use that is currently being
  // iterated off the use list of `input`; confirm that the range-for over
  // input.getUses() still reaches the remaining uses afterwards.
  for (OpOperand &operand : input.getUses()) {
    auto brotherCast = llvm::dyn_cast<CastOp>(operand.getOwner());
    if (!brotherCast) {
      if (stepUnchanged <= queue.size())
        return false;
      builder.setInsertionPoint(operand.getOwner());
      brotherCast = markBackward(
          builder.create<CastOp>(loc, cast.getResult(0).getType(), input));
      auto newCast = markForward(builder.create<CastOp>(
          loc, input.getType(), brotherCast.getResult(0)));
      operand.set(newCast.getResult(0));
      queue.push(brotherCast);
      queue.push(newCast);
    }
    auto type = brotherCast.getResult(0).getType();
    typeToIndices[type].push_back(brotherCast);
  }

  bool first = true;
  // Iterate by reference: a by-value loop variable would copy each bucket,
  // including its vector of casts.
  for (auto &it : typeToIndices) {
    Type &type = it.first;
    llvm::SmallVector<CastOp> &casts = it.second;
    Value newInput = input;
    if (!first) {
      if (Operation *defOp = input.getDefiningOp()) {
        builder.setInsertionPointAfter(defOp);
        Operation *clonedOp = builder.clone(*defOp);
        newInput = clonedOp->getResult(0);
      } else {
        llvm::report_fatal_error("Layout conflict for block arg");
        return false;
      }
    }
    first = false;

    if (Operation *defOp = newInput.getDefiningOp()) {
      builder.setInsertionPointAfter(defOp);
    } else {
      assert(isa<BlockArgument>(newInput) &&
             "Unexpected Value without defining op");
      builder.setInsertionPointToStart(
          llvm::cast<BlockArgument>(newInput).getOwner());
    }

    auto newCast = markBackward(builder.create<CastOp>(loc, type, newInput));
    queue.push(newCast);
    auto newResult = newCast.getResult(0);
    for (CastOp &brotherCast : casts) {
      brotherCast.getResult(0).replaceAllUsesWith(newResult);
      eraseCastOpFromQueue(brotherCast);
    }
  }
  return true;
}
/// Forward step for a cast whose result has several uses: each use gets its
/// own forward cast of the same source (placed right after `cast`), then the
/// original cast is erased. Always returns true.
bool CTAPlanner::processMultiUsersForward(Value castResult, CastOp cast) {
  Value source = cast.getOperand(0);
  Location loc = cast.getLoc();
  OpBuilder builder(cast.getContext());
  builder.setInsertionPointAfter(cast.getOperation());

  // Each iteration moves one use away from `castResult`, so the loop ends
  // once no use is left.
  while (!castResult.use_empty()) {
    OpOperand &use = *castResult.use_begin();
    auto replica =
        markForward(builder.create<CastOp>(loc, castResult.getType(), source));
    use.set(replica.getResult(0));
    queue.push(replica);
  }

  eraseCastOp(cast);
  return true;
}
}
/// Pass wrapper: runs a fresh CTAPlanner on every triton function of the
/// module. Modules with a single CTA are left untouched.
struct PlanCTAPass : public impl::TritonGPUPlanCTAPassBase<PlanCTAPass> {
  /// `clusterInfo_` is optional and is handed to each planner as-is.
  PlanCTAPass(ClusterInfo *clusterInfo_ = nullptr)
      : clusterInfo(clusterInfo_) {}

  void runOnOperation() override {
    ModuleOp moduleOp = getOperation();
    const bool singleCTA = ttg::TritonGPUDialect::getNumCTAs(moduleOp) == 1;
    if (singleCTA)
      return;

    moduleOp.walk([&](triton::FuncOp funcOp) {
      CTAPlanner planner(clusterInfo);
      planner.run(funcOp);

      // The planned function is cloned next to the original, which is then
      // erased.
      OpBuilder builder(funcOp);
      builder.clone(*funcOp.getOperation());
      funcOp.erase();
    });
  }

  ClusterInfo *clusterInfo; ///< Not owned by the pass.
};
/// Factory for PlanCTAPass; `clusterInfo` is forwarded to the pass
/// constructor.
std::unique_ptr<Pass>
createTritonNvidiaGPUPlanCTAPass(ClusterInfo *clusterInfo) {
  std::unique_ptr<Pass> pass = std::make_unique<PlanCTAPass>(clusterInfo);
  return pass;
}
}
}
}
/* TODO List
* - Use ConvertLayoutOp instead of UnrealizedConversionCastOp.
* - Move PlanCTAPass to the front of CoalescePass.
* - Design better tiling strategy for DotOp and ReduceOp.
* - Consider cases where there are more than one DotOps.
* - Use better data structure for erasing CastOps from queue (linked list?).
* - Process eliminable CastOps in higher priority.
* - Fix the clone func bug in PlanCTAPass::runOnOperation.
* - Add some comments to introduce the overall idea of this pass.
* - Add some lit tests for this pass.
*/