#include "Dialect/TritonAMDGPU/IR/Dialect.h"
#include "TritonAMDGPUTransforms/Passes.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/WalkPatternRewriteDriver.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/Support/Debug.h"
#define DEBUG_TYPE "tritonamdgpu-in-thread-transpose"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace ttag = mlir::triton::amdgpu;
namespace mlir {
#define GEN_PASS_DEF_TRITONAMDGPUINTHREADTRANSPOSE
#include "TritonAMDGPUTransforms/Passes.h.inc"
namespace {
/// Builds a ranked tensor type that keeps the shape and element type of
/// `type` but carries `encoding` instead of its current encoding.
///
/// `type` must be a RankedTensorType (enforced by `cast`).
static Type replaceEncoding(Type type, Attribute encoding) {
  auto ranked = cast<RankedTensorType>(type);
  auto dims = ranked.getShape();
  Type elementType = ranked.getElementType();
  return RankedTensorType::get(dims, elementType, encoding);
}
/// Re-issues `load` with every ranked-tensor operand converted to `encoding`
/// and converts the new result back to the original result type.
///
///   %v = tt.load %addr                      : <original type>
///
/// becomes
///
///   %addr_new = ttg.convert_layout %addr    : -> <type with `encoding`>
///   %v_new    = tt.load %addr_new
///   %v        = ttg.convert_layout %v_new   : -> <original type>
///
/// \param rewriter rewriter used to create the new ops and replace `load`
/// \param encoding encoding applied to the tensor operands of the new load
/// \param load     tt.load operation to replace
void refineGlobalLoadLayout(PatternRewriter &rewriter, Attribute encoding,
                            tt::LoadOp load) {
  auto loc = load->getLoc();
  rewriter.setInsertionPoint(load);

  // Non-tensor operands are forwarded untouched; tensor operands get a layout
  // conversion inserted right before the load.
  SmallVector<Value, 4> convertedOperands;
  for (auto operand : load->getOperands()) {
    if (!isa<RankedTensorType>(operand.getType())) {
      convertedOperands.push_back(operand);
      continue;
    }
    Type convertedType = replaceEncoding(operand.getType(), encoding);
    convertedOperands.push_back(
        rewriter.create<ttg::ConvertLayoutOp>(loc, convertedType, operand));
  }

  // The new load inherits all attributes of the original one.
  auto reissuedLoad =
      rewriter.create<tt::LoadOp>(loc, convertedOperands, load->getAttrs());

  // Users keep seeing the original type.
  auto originalType = load.getType();
  Value reissuedResult = reissuedLoad.getResult();
  rewriter.replaceOpWithNewOp<ttg::ConvertLayoutOp>(load, originalType,
                                                    reissuedResult);
}
/// Inserts `convert_layout -> in_thread_transpose` in front of operand 0 of
/// `memStoreOp` and makes the op consume the transposed value.
///
/// \param rewriter        rewriter used for op creation and in-place update
/// \param memStoreOp      a ttg.local_alloc or ttg.local_store (asserted)
/// \param loadShape       shape passed to
///                        InThreadTransposeOp::deduceOutputLayout
/// \param newLoadEncoding blocked encoding the data is converted to before
///                        the transpose
void transposeInRegsitersBeforeStoreInLocalMemory(
    PatternRewriter &rewriter, Operation *memStoreOp,
    ArrayRef<int64_t> loadShape, ttg::BlockedEncodingAttr newLoadEncoding) {
  assert((mlir::isa<ttg::LocalAllocOp, ttg::LocalStoreOp>(memStoreOp)));
  // An op without operands has no data to transpose.
  if (memStoreOp->getNumOperands() == 0)
    return;

  Value storedData = memStoreOp->getOperand(0);
  Type storedType = storedData.getType();
  auto loc = memStoreOp->getLoc();
  rewriter.setInsertionPoint(memStoreOp);

  // Layout the transpose op produces for this shape/encoding pair.
  auto outputLayout =
      ttag::InThreadTransposeOp::deduceOutputLayout(loadShape, newLoadEncoding);
  auto outputEncoding =
      ttg::LinearEncodingAttr::get(memStoreOp->getContext(), outputLayout);

  // First bring the data into the new blocked layout ...
  auto converted = rewriter.create<ttg::ConvertLayoutOp>(
      loc, replaceEncoding(storedType, newLoadEncoding), storedData);
  // ... then transpose it.
  auto transposed = rewriter.create<ttag::InThreadTransposeOp>(
      loc, replaceEncoding(storedType, outputEncoding), converted);

  rewriter.modifyOpInPlace(memStoreOp,
                           [&] { memStoreOp->setOperand(0, transposed); });
}
/// Builds an AMDRotatingSharedEncodingAttr for a dot operand.
///
/// vec / perPhase / maxPhase are taken from a temporary
/// SwizzledSharedEncodingAttr built for the same operand; the order is
/// {1, 0} for operand 0 and {0, 1} for operand 1.
///
/// \param operandType tensor type whose encoding must be a
///                    DotOperandEncodingAttr (enforced by `cast`)
Attribute createNewSharedEncoding(RankedTensorType operandType) {
  auto dotEnc = cast<ttg::DotOperandEncodingAttr>(operandType.getEncoding());
  auto *context = operandType.getContext();
  auto ctaLayout = ttg::getCTALayout(dotEnc);
  auto elemBits = operandType.getElementTypeBitWidth();

  const bool isSecondOperand = dotEnc.getOpIdx() == 1;
  SmallVector<unsigned> order{isSecondOperand ? 0u : 1u,
                              isSecondOperand ? 1u : 0u};

  // The swizzled attribute is only used to derive the swizzling parameters.
  auto swizzled = ttg::SwizzledSharedEncodingAttr::get(
      context, dotEnc, operandType.getShape(), order, ctaLayout, elemBits,
      false);

  return ttg::AMDRotatingSharedEncodingAttr::get(
      context, swizzled.getVec(), swizzled.getPerPhase(),
      swizzled.getMaxPhase(), order, ctaLayout);
}
/// Replaces the encoding of the memory descriptor value `memVal` with
/// `newEncoding`, keeping shape, element type, memory space and mutability.
///
/// Values whose current encoding is not a SwizzledSharedEncodingAttr are left
/// untouched (presumably they were already rewritten by an earlier call).
/// The type change is reported to the rewriter as an in-place modification of
/// the op that owns the block `memVal` lives in.
void changeSharedEncoding(PatternRewriter &rewriter, Value memVal,
                          Attribute newEncoding) {
  auto currentType = cast<ttg::MemDescType>(memVal.getType());
  if (!isa<ttg::SwizzledSharedEncodingAttr>(currentType.getEncoding()))
    return;

  auto updatedType = ttg::MemDescType::get(
      currentType.getShape(), currentType.getElementType(), newEncoding,
      currentType.getMemorySpace(), currentType.getMutableMemory());

  Operation *enclosingOp = memVal.getParentBlock()->getParentOp();
  rewriter.startOpModification(enclosingOp);
  memVal.setType(updatedType);
  rewriter.finalizeOpModification(enclosingOp);
}
/// Operations and values discovered while tracing the path from global loads
/// to the shared memory read by one ttg.local_load.
struct GlobalToSharedMemoryOpChain {
// tt.load operations feeding the local alloc/store operations below.
SetVector<tt::LoadOp> globalLoads;
// ttg.local_alloc and ttg.local_store operations reachable from the root.
SetVector<Operation *> localAllocStores;
// Shared memory values collected during traversal: definitions found by the
// backward search plus control-flow results/block arguments found by the
// forward search. May contain duplicates.
SmallVector<Value> sharedMemVals;
};
// Forward declaration: the per-op helpers below recurse back into this
// function, which is defined after them.
FailureOr<SmallVector<Value>>
traverseCFForValueDefs(Value val, SetVector<Value> &visitedVals);
/// Finds definitions of the body block argument number `argIdx` of `forOp`.
///
/// Argument 0 is treated as the induction variable; for it the search
/// continues through the loop's first operand. Any other argument is an
/// iteration argument whose value comes either from the matching yield
/// operand (inside the loop) or from the matching loop operand (outside).
///
/// \returns the concatenated definitions (yield side first), or failure if
///          any of the searches failed.
FailureOr<SmallVector<Value>>
traverseForOpForDefs(scf::ForOp forOp, int argIdx,
                     SetVector<Value> &visitedVals) {
  const int iterArgIdx = argIdx - 1;
  if (iterArgIdx < 0)
    return traverseCFForValueDefs(forOp.getOperand(0), visitedVals);

  // Both directions are always explored before the results are inspected.
  Operation *yieldOp = forOp.getBody()->getTerminator();
  auto fromYield =
      traverseCFForValueDefs(yieldOp->getOperand(iterArgIdx), visitedVals);
  const int loopOperandIdx = iterArgIdx + forOp.getNumControlOperands();
  auto fromInit =
      traverseCFForValueDefs(forOp.getOperand(loopOperandIdx), visitedVals);
  if (failed(fromYield) || failed(fromInit))
    return failure();

  SmallVector<Value> defs(fromYield->begin(), fromYield->end());
  defs.append(fromInit->begin(), fromInit->end());
  return defs;
}
/// Collects definitions reachable through operand `argIdx` of the `then` and
/// `else` yields of `ifOp` (in that order).
///
/// \returns the concatenated definitions, or failure if any search failed.
FailureOr<SmallVector<Value>>
traverseIfOpForDefs(scf::IfOp ifOp, int argIdx, SetVector<Value> &visitedVals) {
  SmallVector<Value> defs;
  // Both yields are fetched up front, then visited in then/else order.
  for (scf::YieldOp yield : {ifOp.thenYield(), ifOp.elseYield()}) {
    if (!yield)
      continue;
    auto branchDefs =
        traverseCFForValueDefs(yield->getOperand(argIdx), visitedVals);
    if (failed(branchDefs))
      return failure();
    defs.append(branchDefs->begin(), branchDefs->end());
  }
  return defs;
}
/// Collects definitions for position `argIdx` of `whileOp` by following the
/// matching scf.yield operand first and the matching init operand second.
///
/// \returns the concatenated definitions (yield side first), or failure as
///          soon as one of the searches fails.
FailureOr<SmallVector<Value>>
traverseWhileOpForDefs(scf::WhileOp whileOp, int argIdx,
                       SetVector<Value> &visitedVals) {
  auto yieldOp = whileOp.getYieldOp();
  auto fromYield =
      traverseCFForValueDefs(yieldOp->getOperand(argIdx), visitedVals);
  if (failed(fromYield))
    return failure();

  auto fromInit =
      traverseCFForValueDefs(whileOp.getInits()[argIdx], visitedVals);
  if (failed(fromInit))
    return failure();

  SmallVector<Value> defs(fromYield->begin(), fromYield->end());
  defs.append(fromInit->begin(), fromInit->end());
  return defs;
}
/// Finds definitions of result number `argIdx` of the control flow operation
/// `regionBranch` by looking at the values its regions forward to it.
///
/// For scf.while the results are produced by the trailing operands of the
/// scf.condition terminator (this is also how the forward search in
/// traverseCFForValueUses models it), so the search continues from there.
/// For every other operation, operand `argIdx` of each scf.yield directly
/// nested in `regionBranch` is followed.
///
/// \returns the concatenated definitions, or failure if any search failed.
FailureOr<SmallVector<Value>>
traverseRegionBranchOpForDefs(RegionBranchOpInterface regionBranch, int argIdx,
                              SetVector<Value> &visitedVals) {
  // scf.yield of scf.while feeds the "before" block arguments, not the
  // results; indexing it with a result number would be wrong and may be out
  // of range when the number of inits and results differ.
  if (auto whileOp = dyn_cast<scf::WhileOp>(regionBranch.getOperation())) {
    Value forwarded = whileOp.getConditionOp().getArgs()[argIdx];
    return traverseCFForValueDefs(forwarded, visitedVals);
  }

  // Collect only yields that belong to this operation, not to nested ones.
  llvm::SmallVector<scf::YieldOp> yieldOps;
  regionBranch->walk([&](scf::YieldOp op) {
    if (op->getParentOp() == regionBranch) {
      yieldOps.push_back(op);
    }
  });

  SmallVector<Value> foundDefs;
  for (auto yieldOp : yieldOps) {
    auto ops = traverseCFForValueDefs(yieldOp->getOperand(argIdx), visitedVals);
    if (failed(ops))
      return failure();
    foundDefs.append(ops.value());
  }
  return foundDefs;
}
/// Walks backwards through scf control flow to find every value that may
/// define `val`.
///
/// - A result of a RegionBranchOpInterface op is traced into the op.
/// - Any other op result is a leaf and is returned as is.
/// - A block argument is traced through its parent scf.for / scf.if /
///   scf.while.
///
/// Every control-flow value passed on the way is appended to the result as
/// well. Values already in `visitedVals` yield an empty result, which also
/// terminates cycles through loops.
///
/// \param val         value to trace
/// \param visitedVals values seen so far; updated by this call
/// \returns all found values, or failure when a function argument, a block
///          without parent op, or unsupported control flow is reached.
FailureOr<SmallVector<Value>>
traverseCFForValueDefs(Value val, SetVector<Value> &visitedVals) {
  if (visitedVals.contains(val))
    return SmallVector<Value>{};
  visitedVals.insert(val);
  LDBG(" traverseCFForValueDefs processing " << val);

  // Appends `val` itself to a successful nested search result.
  auto attachValue =
      [val](
          FailureOr<SmallVector<Value>> res) -> FailureOr<SmallVector<Value>> {
    if (failed(res))
      return failure();
    SmallVector<Value> result(std::move(res.value()));
    result.push_back(val);
    return result;
  };

  // Result of a control flow op: look inside of it.
  if (auto regionBranch = val.getDefiningOp<RegionBranchOpInterface>()) {
    auto resId = cast<OpResult>(val).getResultNumber();
    return attachValue(
        traverseRegionBranchOpForDefs(regionBranch, resId, visitedVals));
  }

  // Neither a control flow result nor a block argument: this is a leaf.
  auto blockArg = dyn_cast<BlockArgument>(val);
  if (!blockArg)
    return SmallVector<Value>{val};

  // Block argument: look at what the enclosing op forwards into the block.
  Block *block = blockArg.getOwner();
  Operation *parentOp = block->getParentOp();
  if (!parentOp) {
    LDBG(" block without parent op, can not analyze further");
    return failure();
  }
  if (isa<tt::FuncOp>(parentOp)) {
    LDBG(" can not traverse def-use chains, found function argument");
    return failure();
  }
  int argIdx = blockArg.getArgNumber();
  if (auto forOp = dyn_cast<scf::ForOp>(parentOp))
    return attachValue(traverseForOpForDefs(forOp, argIdx, visitedVals));
  if (auto ifOp = dyn_cast<scf::IfOp>(parentOp))
    return attachValue(traverseIfOpForDefs(ifOp, argIdx, visitedVals));
  if (auto whileOp = dyn_cast<scf::WhileOp>(parentOp)) {
    // Arguments of the "after" region are the trailing operands of
    // scf.condition; only "before" region arguments are fed by the inits and
    // by scf.yield. Mixing them up may index out of range when the two
    // argument lists differ in length.
    if (block->getParent() == &whileOp.getAfter()) {
      Value forwarded = whileOp.getConditionOp().getArgs()[argIdx];
      return attachValue(traverseCFForValueDefs(forwarded, visitedVals));
    }
    return attachValue(traverseWhileOpForDefs(whileOp, argIdx, visitedVals));
  }
  LDBG(" can not traverse def-use chains, unsupported control flow "
       "operation");
  return failure();
}
/// Result of a forward (def -> use) search through scf control flow.
struct ForwardSearchAnalysis {
// Users of the traced value that are not scf control flow operations.
SmallVector<Operation *> ops;
// Control flow results and block arguments the traced value is forwarded to.
SmallVector<Value> transitiveCF;
};
/// Walks forward through scf control flow to find every user of `val`.
///
/// Uses by scf.yield, scf.for, scf.while and scf.condition are followed into
/// the block arguments / results the value is forwarded to; those forwarded
/// values are recorded in `transitiveCF`. All remaining users are recorded in
/// `ops`. Values already in `visitedVals` yield an empty result.
///
/// \param val         value to trace
/// \param visitedVals values seen so far; updated by this call
/// \returns the collected users and forwarded values, or failure when the
///          value reaches tt.return or control flow that is not handled.
FailureOr<ForwardSearchAnalysis>
traverseCFForValueUses(Value val, SetVector<Value> &visitedVals) {
  if (visitedVals.contains(val))
    return ForwardSearchAnalysis{};
  visitedVals.insert(val);
  LDBG(" traverseCFForValueUses processing " << val);
  ForwardSearchAnalysis result;

  // Continues the search from `forwarded` and merges what it finds into
  // `result`: found ops, then `forwarded` itself, then its transitive values.
  auto follow = [&](Value forwarded) -> LogicalResult {
    auto nested = traverseCFForValueUses(forwarded, visitedVals);
    if (failed(nested))
      return failure();
    result.ops.append(nested->ops);
    result.transitiveCF.push_back(forwarded);
    result.transitiveCF.append(nested->transitiveCF);
    return success();
  };

  for (auto &use : val.getUses()) {
    auto user = use.getOwner();
    LDBG(" processing user " << *user);
    if (isa<tt::ReturnOp>(user)) {
      LDBG(" Reached return from function");
      return failure();
    }

    if (isa<scf::YieldOp>(user)) {
      auto opIdx = use.getOperandNumber();
      auto parent = user->getParentOp();
      // Data flowing out of the parent operation.
      if (isa<scf::ForOp, scf::IfOp>(parent)) {
        if (failed(follow(parent->getResult(opIdx))))
          return failure();
      }
      // Data flowing back into the parent's regions.
      if (auto forOp = dyn_cast<scf::ForOp>(parent)) {
        int forBodyOperandIdx = opIdx + forOp.getNumInductionVars();
        if (failed(follow(forOp.getBody()->getArgument(forBodyOperandIdx))))
          return failure();
      } else if (auto whileOp = dyn_cast<scf::WhileOp>(parent)) {
        if (failed(follow(whileOp.getBeforeArguments()[opIdx])))
          return failure();
      } else if (!isa<scf::IfOp>(parent)) {
        // scf.if is fine: it does not forward yielded values into its blocks.
        LDBG(" Reached unsupported CF operation in forward CF traversal");
        return failure();
      }
      continue;
    }

    if (auto forOp = dyn_cast<scf::ForOp>(user)) {
      LDBG(" for op num operands: " << forOp.getNumOperands());
      LDBG(" for op body num operands: "
           << forOp.getBody()->getNumArguments());
      assert(use.getOperandNumber() >= forOp.getNumControlOperands());
      int blockArgIdx = use.getOperandNumber() - forOp.getNumControlOperands() +
                        forOp.getNumInductionVars();
      if (failed(follow(forOp.getBody()->getArgument(blockArgIdx))))
        return failure();
      continue;
    }

    if (auto whileOp = dyn_cast<scf::WhileOp>(user)) {
      int blockArgIdx = use.getOperandNumber();
      if (failed(follow(whileOp.getBeforeArguments()[blockArgIdx])))
        return failure();
      continue;
    }

    if (auto condOp = dyn_cast<scf::ConditionOp>(user)) {
      // Operand 0 is the loop predicate; it is not forwarded anywhere, so
      // there is no block argument or result to continue with. Bail out
      // instead of indexing with -1.
      if (use.getOperandNumber() == 0) {
        LDBG(" can not traverse def-use chains, value is used as "
             "scf::ConditionOp predicate");
        return failure();
      }
      // -1 skips the predicate operand.
      int argIdx = use.getOperandNumber() - 1;
      auto whileOp = condOp.getParentOp<scf::WhileOp>();
      if (!whileOp) {
        LDBG(" can not traverse scf::ConditionOp successors");
        return failure();
      }
      // The value goes both into the "after" region and out of the loop.
      if (failed(follow(whileOp.getAfterArguments()[argIdx])))
        return failure();
      if (failed(follow(whileOp.getResult(argIdx))))
        return failure();
      continue;
    }

    if (isa<scf::SCFDialect>(user->getDialect())) {
      LDBG(" can not traverse def-use chains, unsupported control flow "
           "operation");
      return failure();
    }
    result.ops.push_back(user);
  }
  return result;
}
/// Returns every operation of type `Op` that defines one of the values found
/// by traverseCFForValueDefs for `val`.
///
/// Values without a defining op (block arguments) and values defined by other
/// operation types are skipped. Fails if the traversal fails.
template <typename Op>
FailureOr<SmallVector<Op>> findAllDefiningOps(Value val) {
  SetVector<Value> seen;
  auto defs = traverseCFForValueDefs(val, seen);
  if (failed(defs))
    return failure();

  SmallVector<Op> matches;
  for (Value def : *defs) {
    // getDefiningOp<Op>() is null for block arguments and for other op types.
    if (auto typed = def.getDefiningOp<Op>())
      matches.push_back(typed);
  }
  return matches;
}
/// Starting from `root`, collects the shared memory operations and values
/// connected to it through def-use chains and scf control flow.
///
/// The search proceeds in steps: for each operation of the current step the
/// shared memory operand is traced backwards to its definitions and the
/// shared memory result is traced forwards to its users; newly discovered
/// operations form the next step.
///
/// \returns the discovered network, or failure if a traversal fails, a
///          direct-to-LDS load is found, or an unexpected operation is hit.
FailureOr<GlobalToSharedMemoryOpChain>
findReachableSMemOps(ttg::LocalLoadOp root) {
  SetVector<Operation *> seenOps;
  SetVector<Value> seenForward;
  SetVector<Value> seenBackward;
  GlobalToSharedMemoryOpChain network;

  SmallVector<Operation *> frontier{root};
  while (!frontier.empty()) {
    LDBG("begin new step in smem op analysis");
    SmallVector<Operation *> nextFrontier;
    for (Operation *current : frontier) {
      // insert() reports whether the op is new.
      if (!seenOps.insert(current))
        continue;
      LDBG(" processing in smem op analysis: " << *current);

      // Shared memory value consumed / produced by `current`, if any.
      Value consumed;
      Value produced;
      if (isa<ttg::LocalAllocOp>(current)) {
        network.localAllocStores.insert(current);
        produced = current->getResult(0);
      } else if (isa<ttg::LocalStoreOp>(current)) {
        network.localAllocStores.insert(current);
        consumed = current->getOperand(1);
      } else if (isa<ttg::MemDescIndexOp>(current)) {
        produced = current->getResult(0);
        consumed = current->getOperand(0);
      } else if (isa<ttg::LocalLoadOp, ttg::LocalDeallocOp>(current)) {
        consumed = current->getOperand(0);
      } else if (isa<ttg::AsyncCopyGlobalToLocalOp,
                     tt::amdgpu::BufferLoadToLocalOp>(current)) {
        LDBG(" skip because of direct-to-lds load");
        return failure();
      } else {
        // Not expected to be reachable: hard error in assert-enabled builds,
        // plain failure otherwise.
        LDBG(" catched operation unrelated to shared memory" << *current);
        assert(false && " catched operation unrelated to shared memory");
        return failure();
      }

      if (consumed) {
        auto defs = traverseCFForValueDefs(consumed, seenBackward);
        if (failed(defs))
          return failure();
        for (Value def : *defs) {
          network.sharedMemVals.push_back(def);
          Operation *defOp = def.getDefiningOp();
          if (defOp && isa<ttg::MemDescIndexOp, ttg::LocalAllocOp>(defOp))
            nextFrontier.push_back(defOp);
        }
      }

      if (produced) {
        auto uses = traverseCFForValueUses(produced, seenForward);
        if (failed(uses))
          return failure();
        network.sharedMemVals.append(uses->transitiveCF);
        nextFrontier.append(uses->ops);
      }
    }
    frontier = std::move(nextFrontier);
  }
  return network;
}
/// Computes how many elements along dimension `dimIdx` are available per
/// thread: shape[dimIdx] / (threadsPerWarp[dimIdx] * warpsPerCTA[dimIdx]),
/// but never less than 1.
///
/// \returns 0 if the encoding of `type` is not a BlockedEncodingAttr.
unsigned getMaxSizePerThread(RankedTensorType type, int dimIdx) {
  auto blocked = dyn_cast<ttg::BlockedEncodingAttr>(type.getEncoding());
  if (!blocked)
    return 0;

  auto threadsPerWarp = blocked.getThreadsPerWarp();
  auto warpsPerCTA = blocked.getWarpsPerCTA();
  auto threadsAlongDim = threadsPerWarp[dimIdx] * warpsPerCTA[dimIdx];
  int elemsPerThread = type.getShape()[dimIdx] / threadsAlongDim;
  // Integer division may give 0 when the dimension is smaller than the
  // number of threads covering it.
  return elemsPerThread < 1 ? 1 : elemsPerThread;
}
/// Checks whether the in-thread transpose rewrite can be applied to `lLoad`
/// and gathers the operations it has to touch.
///
/// Requirements checked here:
///  - the local_load result has a dot operand encoding whose parent is an
///    AMD MFMA or WMMA encoding;
///  - the shared memory network around the load can be traversed and contains
///    at least one local alloc/store;
///  - every stored tensor has a blocked encoding whose fastest dimension
///    (order[0]) is not the K dimension;
///  - at least one tt.load feeds the stores, all such loads have the same
///    rank-2 result type, and at least 2 elements per thread are available
///    along K.
///
/// \returns the matched chain of operations, or failure.
llvm::FailureOr<GlobalToSharedMemoryOpChain>
matchInThreadTransposePattern(ttg::LocalLoadOp lLoad) {
  auto resultType = cast<RankedTensorType>(lLoad.getType());
  auto dotOperandEnc =
      dyn_cast<ttg::DotOperandEncodingAttr>(resultType.getEncoding());
  if (!dotOperandEnc)
    return failure();

  // K is dim 1 for operand 0 and dim 0 for operand 1.
  int kDimNum = dotOperandEnc.getOpIdx() == 0 ? 1 : 0;

  const bool parentSupported =
      isa<ttg::AMDMfmaEncodingAttr, ttg::AMDWmmaEncodingAttr>(
          dotOperandEnc.getParent());
  if (!parentSupported) {
    LDBG("Operand's parent encoding is not MFMA");
    return failure();
  }

  auto smemNetwork = findReachableSMemOps(lLoad);
  if (failed(smemNetwork)) {
    LDBG("Failed to traverse shared memmory operation network");
    return failure();
  }
  GlobalToSharedMemoryOpChain pattern = smemNetwork.value();
  if (pattern.localAllocStores.empty()) {
    LDBG("Did not find local alloc or store operations");
    return failure();
  }

  // Find the global loads feeding each store and validate the stored layout.
  for (Operation *storeOp : pattern.localAllocStores) {
    LDBG("processing local mem store operation: " << *storeOp);
    // Nothing is stored by an op without operands.
    if (storeOp->getNumOperands() == 0)
      continue;

    Value stored = storeOp->getOperand(0);
    auto storedEnc = dyn_cast<ttg::BlockedEncodingAttr>(
        cast<RankedTensorType>(stored.getType()).getEncoding());
    if (!storedEnc)
      return failure();

    auto order = storedEnc.getOrder();
    if (order[0] == kDimNum) {
      return failure();
    }

    auto feedingLoads = findAllDefiningOps<tt::LoadOp>(stored);
    if (failed(feedingLoads)) {
      LDBG("Failed to traverse path to global loads");
      return failure();
    }
    pattern.globalLoads.insert_range(feedingLoads.value());
  }

  // Debug dump of everything that was found.
  LDBG("found global loads: " << pattern.globalLoads.size());
  for (auto load : pattern.globalLoads)
    LDBG(load);
  LDBG("found local alloc stores: " << pattern.localAllocStores.size());
  for (auto local : pattern.localAllocStores)
    LDBG(*local);
  LDBG("found shared mem values: " << pattern.sharedMemVals.size());
  for (auto val : pattern.sharedMemVals)
    LDBG(val);

  if (pattern.globalLoads.empty()) {
    LDBG("Did not find global load operation");
    return failure();
  }

  // The first load defines the type all other loads have to agree with.
  auto referenceType = cast<RankedTensorType>(
      pattern.globalLoads.front().getResult().getType());
  if (referenceType.getRank() != 2)
    return failure();

  if (getMaxSizePerThread(referenceType, kDimNum) < 2) {
    LDBG("Can not extend load layout");
    return failure();
  }

  for (auto load : pattern.globalLoads) {
    if (load->getResult(0).getType() != referenceType) {
      LDBG("Mismatch between global loads result types");
      return failure();
    }
  }
  return pattern;
}
/// Builds the blocked encoding used for global loads that are going to be
/// transposed in registers.
///
/// The new encoding equals the one of `loadType` except for sizePerThread
/// along the K dimension, which becomes
///   min(max elements per thread along K, dsBitWidth / element bit width).
///
/// \param dotOperandIdx index of the dot operand (0 -> K is dim 1,
///                      otherwise K is dim 0)
/// \param loadType      type of the global load result; its encoding must be
///                      a BlockedEncodingAttr (enforced by `cast`)
ttg::BlockedEncodingAttr getTransposableBlockedEnc(int dotOperandIdx,
                                                   RankedTensorType loadType) {
  auto shape = loadType.getShape();
  int kDimNum = dotOperandIdx == 0 ? 1 : 0;
  auto loadEnc = loadType.getEncoding();
  auto blockedEnc = cast<ttg::BlockedEncodingAttr>(loadEnc);

  auto maxkDimSizePerThread = getMaxSizePerThread(loadType, kDimNum);
  auto elemBitwidth = loadType.getElementType().getIntOrFloatBitWidth();
  // Width, in bits, that one thread should cover along K.
  const unsigned dsBitWidth = 64;
  // Elements wider than dsBitWidth would give 0 here, which is not a valid
  // sizePerThread; use at least one element.
  const unsigned elemsPerDsAccess = std::max(1u, dsBitWidth / elemBitwidth);
  auto newKDimSize = std::min(maxkDimSizePerThread, elemsPerDsAccess);
  // Report both inputs of the minimum, not its result.
  LDBG("Choose the minimum of numIters: " << maxkDimSizePerThread
                                          << " and numElements: "
                                          << elemsPerDsAccess);

  SmallVector<unsigned> newSizePerThread{blockedEnc.getSizePerThread()};
  newSizePerThread[kDimNum] = newKDimSize;

  // Everything else is carried over from the current encoding.
  auto order = blockedEnc.getOrder();
  auto ctx = blockedEnc.getContext();
  auto numWarps = product(blockedEnc.getWarpsPerCTA());
  auto threadsPerWarp = product(blockedEnc.getThreadsPerWarp());
  auto numCTAs = product(blockedEnc.getCTALayout().getCTAsPerCGA());
  return ttg::BlockedEncodingAttr::get(ctx, shape, newSizePerThread, order,
                                       numWarps, threadsPerWarp, numCTAs);
}
/// Rewrite pattern rooted at ttg.local_load.
///
/// When matchInThreadTransposePattern succeeds, the pattern
///  1. re-issues all feeding global loads with a new blocked encoding,
///  2. inserts an in-thread transpose in front of every local alloc/store,
///  3. switches all involved shared memory values to a new shared encoding.
class InThreadTransposePattern : public OpRewritePattern<ttg::LocalLoadOp> {
public:
  InThreadTransposePattern(MLIRContext *context, PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit) {}

  LogicalResult matchAndRewrite(ttg::LocalLoadOp localLoad,
                                PatternRewriter &rewriter) const override {
    LDBG("Consider " << localLoad);
    auto matched = matchInThreadTransposePattern(localLoad);
    if (failed(matched)) {
      LDBG("Failed to match InThreadTranspose pattern and nothing to be "
           "done");
      return failure();
    }
    GlobalToSharedMemoryOpChain chain = matched.value();
    auto dotOperandEnc =
        cast<ttg::DotOperandEncodingAttr>(localLoad.getType().getEncoding());

    // Step 1: new blocked layout, derived from the first global load.
    LDBG("Adjusting global loads");
    auto referenceType = cast<RankedTensorType>(
        chain.globalLoads.front().getResult().getType());
    ttg::BlockedEncodingAttr transposableEnc =
        getTransposableBlockedEnc(dotOperandEnc.getOpIdx(), referenceType);
    ArrayRef<int64_t> loadShape = referenceType.getShape();
    for (tt::LoadOp globalLoad : chain.globalLoads) {
      LDBG("operand newBlockedEnc = " << transposableEnc);
      refineGlobalLoadLayout(rewriter, transposableEnc, globalLoad);
    }

    // Step 2: transpose the data right before it is written.
    LDBG("Inserting transpose in registers before store in LDS");
    for (Operation *storeOp : chain.localAllocStores) {
      transposeInRegsitersBeforeStoreInLocalMemory(rewriter, storeOp,
                                                   loadShape, transposableEnc);
    }

    // Step 3: update the encoding of every shared memory value.
    LDBG("Adjust shared encoding");
    Attribute rotatingEnc =
        createNewSharedEncoding(cast<RankedTensorType>(localLoad.getType()));
    for (Value sharedVal : chain.sharedMemVals)
      changeSharedEncoding(rewriter, sharedVal, rotatingEnc);

    return success();
  }
};
}
/// Pass that applies InThreadTransposePattern to the function it runs on,
/// using the walk-based pattern driver.
class TritonAMDGPUInThreadTransposePass
    : public impl::TritonAMDGPUInThreadTransposeBase<
          TritonAMDGPUInThreadTransposePass> {
public:
  void runOnOperation() override {
    tt::FuncOp funcOp = getOperation();
    MLIRContext *context = funcOp.getContext();

    RewritePatternSet patternSet(context);
    patternSet.add<InThreadTransposePattern>(context, /*benefit=*/1);
    walkAndApplyPatterns(funcOp, std::move(patternSet));
  }
};
}