#include "Utility.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"

#include <cassert>
#include <cmath>
namespace mlir::LLVM::AMD {
using namespace mlir;

/// Returns the order in which `layout` steps through its CTA tiles, i.e. the
/// repetitions of a single CTA tile that one thread holds in registers.
///
/// The register bases of `layout` are split in two groups:
///  * the first log2(registersPerThreadPerCTATile) bases stay inside one CTA
///    tile and are ignored here;
///  * each remaining basis moves to another CTA tile. For each of them, taken
///    from least to most significant, the index of its first nonzero output
///    coordinate (the dimension it steps along) is appended to the order.
/// Finally, every dimension from the encoding's own order (`getOrder()`) that
/// is not yet present is appended, so that order acts as the fallback.
///
/// @param ctx    MLIR context used to build the linear encoding and dim names.
/// @param layout Linear layout with a "register" input dimension.
/// @return Output-dimension indices, each appearing at most once.
SmallVector<unsigned> getCTATileOrder(MLIRContext *ctx,
                                      const triton::LinearLayout &layout) {
  auto llEnc = triton::gpu::LinearEncodingAttr::get(ctx, layout);
  auto regDim = StringAttr::get(ctx, "register");

  auto basesIt = layout.getBases().find(regDim);
  assert(basesIt != layout.getBases().end() &&
         "layout must have a register input dimension");
  const auto &bases = basesIt->second;

  // Number of CTA tiles in the layout: all output elements divided by the
  // elements spanned by a single CTA tile. Note this is NOT the number of
  // CTAs per CGA, which is unrelated to how registers repeat over the tensor.
  unsigned totalElems = layout.getTotalOutDimSize();
  unsigned elemsPerCTATile = product(llEnc.getShapePerCTATile());
  assert(elemsPerCTATile != 0 && "CTA tile must not be empty");
  assert(totalElems % elemsPerCTATile == 0 &&
         "total elements must be a multiple of the CTA tile size");
  unsigned numCTATiles = totalElems / elemsPerCTATile;

  // Registers one thread holds for a single CTA tile; its log2 is the index
  // of the first register basis that jumps to another CTA tile.
  unsigned registersPerThreadPerCTATile =
      product(llEnc.basesPerDim(regDim, false)) / numCTATiles;
  assert(registersPerThreadPerCTATile != 0 &&
         "each CTA tile must hold at least one register per thread");
  unsigned startIndex =
      static_cast<unsigned>(std::log2(registersPerThreadPerCTATile));

  llvm::SmallSetVector<unsigned, 8> order;
  for (unsigned i = startIndex; i < bases.size(); ++i) {
    const auto &basis = bases[i];
    // The first nonzero coordinate tells which dimension this basis steps in.
    auto it = llvm::find_if(basis, [](auto v) { return v != 0; });
    if (it != basis.end())
      order.insert(std::distance(basis.begin(), it));
  }

  // Append any dims missing from the tile-derived order, in default order.
  for (unsigned dim : llEnc.getOrder())
    order.insert(dim);

  return order.takeVector();
}
} // namespace mlir::LLVM::AMD