#include "triton/Tools/LayoutUtils.h"
#include "triton/Tools/GenericSwizzling.h"
namespace mlir::triton {
static bool checkSquareSublayout(const LinearLayout &ll,
ArrayRef<StringAttr> dimNames,
function_ref<bool(int, int32_t)> checkBasis) {
if (dimNames.size() == 0) {
return true;
}
LinearLayout sl = ll.sublayout(dimNames, dimNames);
for (StringAttr dim : dimNames) {
if (ll.getInDimSize(dim) != ll.getOutDimSize(dim)) {
return false;
}
}
sl = sl.flattenIns().flattenOuts();
const auto &inDimBases = sl.getBases().begin()->second;
for (auto [b, basis] : llvm::enumerate(inDimBases)) {
if (!checkBasis(b, basis[0])) {
return false;
}
}
return true;
}
/// Returns true if the square sublayout of `ll` over `dimNames` passes the
/// identity test: once flattened, the b-th basis must equal `1 << b`.
/// The emptiness and size checks are those of checkSquareSublayout.
bool squareSublayoutIsIdentity(const LinearLayout &ll,
                               ArrayRef<StringAttr> dimNames) {
  auto isIdentityBasis = [](int bit, int32_t value) {
    return value == (1 << bit);
  };
  return checkSquareSublayout(ll, dimNames, isIdentityBasis);
}
/// Returns a layout derived from `layout` in which every output dimension is
/// at most as large as the size given for it in `shape`.
///
/// For each output dimension that is too large, the bases that contribute to
/// it are visited from the largest component to the smallest, and the
/// dimension is halved per visited basis until it fits. A visited basis either
/// has its component along that dimension set to zero, or - when it belongs to
/// the "register" input dimension and `broadcastRegisters` is false - is
/// removed from the register bases altogether.
///
/// \param layout              source layout.
/// \param shape               desired size per output dimension; must have one
///                            entry per output dimension of `layout`.
/// \param broadcastRegisters  if false, shrunk register bases are dropped
///                            instead of zeroed.
LinearLayout
ensureLayoutNotLargerThan(const LinearLayout &layout,
                          const llvm::SmallDenseMap<StringAttr, int64_t> &shape,
                          bool broadcastRegisters) {
  assert(shape.size() == layout.getNumOutDims());
  if (shape.empty())
    return layout;

  MLIRContext *ctx = shape.begin()->first.getContext();
  auto bases = layout.getBases();
  auto kRegister = StringAttr::get(ctx, "register");

  // Indices of register bases scheduled for removal (only populated when
  // broadcastRegisters is false).
  std::set<int32_t> droppedRegisterBases;

  for (auto outDim : llvm::enumerate(layout.getOutDimNames())) {
    const auto outIdx = outDim.index();
    auto outName = outDim.value();
    int32_t currentSize = layout.getOutDimSize(outName);
    int32_t targetSize = shape.lookup(outName);
    if (currentSize <= targetSize)
      continue;
    assert(currentSize % targetSize == 0);

    // Gather <input dim, basis index, component> for every basis with a
    // non-zero component along this output dimension.
    std::vector<std::tuple<StringAttr, int, int>> contributors;
    for (const auto &[inName, inBases] : bases) {
      for (size_t k = 0; k < inBases.size(); ++k) {
        auto component = inBases[k][outIdx];
        if (component != 0) {
          assert(llvm::isPowerOf2_32(component));
          contributors.emplace_back(inName, k, component);
        }
      }
    }

    // Largest component first.
    llvm::sort(contributors, [](auto lhs, auto rhs) {
      return std::get<2>(lhs) > std::get<2>(rhs);
    });

    for (const auto &entry : contributors) {
      if (currentSize <= targetSize)
        break;
      StringAttr inName = std::get<0>(entry);
      int basisIdx = std::get<1>(entry);
      if (broadcastRegisters || inName != kRegister) {
        bases[inName][basisIdx][outIdx] = 0;
      } else {
        droppedRegisterBases.insert(basisIdx);
      }
      currentSize >>= 1;
    }
  }

  if (!broadcastRegisters) {
    // Rebuild the register bases without the ones marked above.
    std::vector<std::vector<int32_t>> keptRegisterBases;
    for (auto [idx, basis] : llvm::enumerate(bases[kRegister])) {
      if (droppedRegisterBases.count(idx) == 0)
        keptRegisterBases.push_back(std::move(basis));
    }
    bases[kRegister] = std::move(keptRegisterBases);
  }

  // Clamp the declared output sizes to the requested shape.
  auto outDims = layout.getOutDims();
  for (auto &[outName, outSize] : outDims)
    outSize = std::min<int32_t>(outSize, shape.lookup(outName));

  return LinearLayout(std::move(bases), std::move(outDims),
                      /*requireSurjective=*/false);
}
/// Returns a layout derived from `layout` in which every output dimension is
/// at least as large as the size given for it in `shape`.
///
/// Each output dimension is grown by multiplying the layout with a 1D identity
/// of size `desired / actual`, mapped from the first input dimension of
/// `layout` (asserted to be "register" or "offset").
///
/// \param layout  source layout.
/// \param shape   desired size per output dimension; must have one entry per
///                output dimension of `layout`.
LinearLayout ensureLayoutNotSmallerThan(
    const LinearLayout &layout,
    const llvm::SmallDenseMap<StringAttr, int64_t> &shape) {
  assert(shape.size() == layout.getNumOutDims());
  if (shape.empty())
    return layout;

  StringAttr kDim = *layout.getInDimNames().begin();
  assert(kDim == "register" || kDim == "offset");

  LinearLayout grown = layout;
  for (StringAttr outName : layout.getOutDimNames()) {
    const int32_t currentSize = layout.getOutDimSize(outName);
    const int32_t targetSize = shape.lookup(outName);
    assert(currentSize > targetSize || targetSize % currentSize == 0);
    // Integer division: the factor is 0 when the dimension is already larger
    // than requested.
    const int32_t factor = targetSize / currentSize;
    grown *= LinearLayout::identity1D(factor, kDim, outName);
    assert(grown.getOutDimSize(outName) >= targetSize);
  }
  return grown;
}
/// Returns the names "dim0", "dim1", ..., "dim<rank-1>" as StringAttrs created
/// in `ctx`. The result is empty when `rank` is not positive.
SmallVector<StringAttr> standardOutDimNames(MLIRContext *ctx, int rank) {
  SmallVector<StringAttr> names;
  int index = 0;
  while (index < rank) {
    StringAttr name = StringAttr::get(ctx, "dim" + llvm::Twine(index));
    names.push_back(name);
    ++index;
  }
  return names;
}
/// Pairs the standard dimension names ("dim0", "dim1", ...) with the sizes in
/// `dstShape`, in order. Each size is stored as int32_t.
SmallVector<std::pair<StringAttr, int32_t>>
standardOutDimPairs(MLIRContext *ctx, ArrayRef<int64_t> dstShape) {
  SmallVector<StringAttr> names = standardOutDimNames(ctx, dstShape.size());
  SmallVector<std::pair<StringAttr, int32_t>> pairs;
  for (auto [name, extent] : llvm::zip(names, dstShape))
    pairs.emplace_back(name, extent);
  return pairs;
}
/// Builds a layout from `inDimName` to the standard output dimensions
/// ("dim0", "dim1", ...) as a product of 1D identities.
///
/// The factors are multiplied in the sequence given by `order`: the i-th
/// factor is an identity of size `shape[order[i]]` onto output dimension
/// `order[i]`.
///
/// \param inDimName  name of the single input dimension.
/// \param shape      size of each output dimension.
/// \param order      dimension indices; must have the same length as `shape`.
LinearLayout identityStandardND(StringAttr inDimName, ArrayRef<unsigned> shape,
                                ArrayRef<unsigned> order) {
  assert(shape.size() == order.size());
  MLIRContext *ctx = inDimName.getContext();
  const auto rank = shape.size();
  SmallVector<StringAttr> outNames = standardOutDimNames(ctx, rank);

  LinearLayout result = LinearLayout::empty();
  for (size_t pos = 0; pos < rank; ++pos) {
    const int dim = order[pos];
    result *= LinearLayout::identity1D(shape[dim], inDimName, outNames[dim]);
  }
  return result;
}
/// Returns a layout with the same input dimensions, number of bases and
/// output dimension names/sizes as `layout`, but with every basis component
/// set to zero. The result is built without requiring surjectivity.
LinearLayout zerosLike(const LinearLayout &layout) {
  auto zeroedBases = layout.getBases();
  for (auto &inDimBases : llvm::make_second_range(zeroedBases)) {
    for (auto &basis : inDimBases)
      basis.assign(basis.size(), 0);
  }

  SmallVector<std::pair<StringAttr, int32_t>> outDims;
  for (StringAttr name : layout.getOutDimNames())
    outDims.emplace_back(name, layout.getOutDimSize(name));

  return LinearLayout(std::move(zeroedBases), std::move(outDims),
                      /*requireSurjective=*/false);
}
/// Computes a permutation of the register bases of `A` that groups the bases
/// matching the register bases of `B` together.
///
/// Each register basis of `B` (after `B` has been given all of `A`'s output
/// dimensions) is matched against a not-yet-used register basis of `A`. When
/// `left` is false the basis of `B` is first shifted left, per output
/// dimension, by log2(size of A) - log2(size of B). The resulting column order
/// places the matched indices first when `left` is true and last otherwise;
/// the unmatched indices of `A` fill the other side in ascending order.
///
/// \param A     layout whose register bases are permuted; its first input
///              dimension must be "register".
/// \param B     layout to match; its first input dimension must be the same.
/// \param left  selects which side the matched bases are placed on.
/// \returns the column action, or std::nullopt if `B` is larger than `A` in
///          some output dimension or a basis of `B` has no match in `A`.
std::optional<ColumnAction> regPermForDivide(const LinearLayout &A,
                                             const LinearLayout &B, bool left) {
  assert(A.getNumInDims() != 0);
  auto kReg = *A.getInDimNames().begin();
  assert(kReg.str() == "register");
  assert(B.getNumInDims() != 0);
  assert(kReg == *B.getInDimNames().begin());

  // Multiply B by size-1 identities over A's output dimensions so that its
  // bases have one component per output dimension of A (checked by the assert
  // in multiplyByTileSize).
  LinearLayout broadcast;
  for (StringAttr out : A.getOutDimNames()) {
    broadcast *= LinearLayout::identity1D(1, kReg, out);
  }
  auto BBroadcast = broadcast * B;
  const auto &ARegBases = A.getBases().lookup(kReg);
  const auto &BRegBases = BBroadcast.getBases().lookup(kReg);

  // log2 of the size ratio A / B per output dimension. The difference is
  // computed and stored as a signed value: with an unsigned map value the
  // "B is larger than A" test below could never fire and the wrapped value
  // would be used as a shift count.
  llvm::DenseMap<StringAttr, int32_t> log2QuotSize;
  for (StringAttr out : A.getOutDimNames()) {
    const int32_t log2Quot =
        A.getOutDimSizeLog2(out) - BBroadcast.getOutDimSizeLog2(out);
    if (log2Quot < 0)
      return std::nullopt;
    log2QuotSize[out] = log2Quot;
  }

  // Scales each component of a basis of B by the quotient size of its
  // output dimension.
  auto multiplyByTileSize =
      [&](ArrayRef<int32_t> bBasis) -> std::vector<int32_t> {
    std::vector<int32_t> result;
    result.reserve(bBasis.size());
    assert(bBasis.size() == A.getNumOutDims());
    for (auto [dim, b] : llvm::zip(A.getOutDimNames(), bBasis)) {
      result.push_back(b << log2QuotSize.lookup(dim));
    }
    return result;
  };

  // For each basis of B, in order, find an unused equal basis of A. Each index
  // of A is used at most once so that repeated bases of B need distinct
  // matches.
  SmallVector<size_t> bIndices;
  SmallVector<bool> used(ARegBases.size(), false);
  for (auto bB : BRegBases) {
    if (!left)
      bB = multiplyByTileSize(bB);
    bool found = false;
    for (size_t j = 0; j < ARegBases.size(); ++j) {
      if (!used[j] && ARegBases[j] == bB) {
        bIndices.push_back(j);
        used[j] = true;
        found = true;
        break;
      }
    }
    if (!found)
      return std::nullopt;
  }

  SmallVector<size_t> remainingIndices;
  for (size_t i = 0; i < ARegBases.size(); ++i) {
    if (!used[i])
      remainingIndices.push_back(i);
  }

  SmallVector<size_t> permOrder = to_vector(llvm::concat<size_t>(
      left ? bIndices : remainingIndices, left ? remainingIndices : bIndices));
  return ColumnAction(permOrder, kReg, ARegBases.size());
}
/// Builds a column action over the first input dimension of `layout`
/// (asserted to be "register") that keeps, in their original order, only the
/// register bases that have at least one non-zero component.
ColumnAction actionRemoveBroadcastedRegs(const LinearLayout &layout) {
  assert(layout.getNumInDims() != 0);
  auto kReg = *layout.getInDimNames().begin();
  assert(kReg.str() == "register");

  const auto &regBases = layout.getBases().lookup(kReg);
  SmallVector<size_t> keptColumns;
  for (auto [column, basis] : llvm::enumerate(regBases)) {
    const bool hasNonZero =
        llvm::any_of(basis, [](int32_t component) { return component != 0; });
    if (hasNonZero)
      keptColumns.push_back(column);
  }
  return ColumnAction(keptColumns, kReg, regBases.size());
}
/// Splits the register bases of `layout` into a "front" and a "back" group
/// and returns a column action that places the front group first.
///
/// A register basis goes to the front when its (single) output component
/// shares no bit with the union of `maskSpanOffsets` and all first components
/// of the bases of `addrLayout`, or when it is equal to one of those first
/// components. Relative order within each group is preserved.
///
/// \param layout           layout with "register" as first input dimension
///                         and exactly one output dimension.
/// \param addrLayout       layout whose basis components contribute to the
///                         bit mask.
/// \param maskSpanOffsets  extra bits for the mask; only the low 32 bits are
///                         used.
/// \returns {2^(number of front bases), column action}.
std::pair<int64_t, ColumnAction>
actionAdditiveStrides(const LinearLayout &layout, const LinearLayout addrLayout,
                      uint64_t maskSpanOffsets) {
  assert(layout.getNumInDims() != 0);
  auto kReg = *layout.getInDimNames().begin();
  assert(kReg.str() == "register");
  assert(layout.getNumOutDims() == 1);

  // Accumulate every bit touched by addrLayout on top of the caller's mask,
  // and remember the individual basis values.
  uint32_t bits = static_cast<uint32_t>(maskSpanOffsets);
  llvm::SetVector<uint32_t> tileBases;
  for (const auto &bases : llvm::make_second_range(addrLayout.getBases())) {
    for (const auto &basis : bases) {
      bits |= basis[0];
      tileBases.insert(basis[0]);
    }
  }

  SmallVector<size_t> front, back;
  for (auto [idx, basis] : llvm::enumerate(layout.getBases().lookup(kReg))) {
    if ((basis[0] & bits) == 0 || tileBases.contains(basis[0])) {
      front.push_back(idx);
    } else {
      back.push_back(idx);
    }
  }

  auto permOrder = to_vector(llvm::concat<size_t>(front, back));
  // Shift in 64 bits: the result type is int64_t, and shifting the int
  // literal 1 is undefined once front.size() reaches 31.
  return {int64_t{1} << front.size(),
          ColumnAction(permOrder, kReg, layout.getInDimSizeLog2(kReg))};
}
/// Expands `values` to one entry per register of `layout` by repeating values
/// across the register bits that `layout` reports as free variables.
///
/// A helper layout is built over the register dimension: free bits map to 0
/// and the remaining bits map, in order, to consecutive bits of the source
/// index. Register r of the result is `values[helper(r)]`.
///
/// \param values  source values; their count must equal the register size of
///                `layout` divided by 2^(number of free register bits).
/// \param layout  layout whose first input dimension is "register".
SmallVector<Value> broadcastAs(const SmallVector<Value> &values,
                               const LinearLayout &layout) {
  assert(layout.getNumInDims() != 0);
  auto kReg = *layout.getInDimNames().begin();
  assert(kReg.str() == "register");
  uint32_t broadcastMask = layout.getFreeVariableMasks().lookup(kReg);
  assert((layout.getInDimSize(kReg) / (1 << llvm::popcount(broadcastMask))) ==
         values.size());

  std::vector<std::vector<int32_t>> indexBases;
  const int numRegBits = layout.getInDimSizeLog2(kReg);
  int nextSrcBit = 0;
  for (int bit = 0; bit < numRegBits; ++bit) {
    const bool isFreeBit = (broadcastMask & (1 << bit)) != 0;
    if (isFreeBit) {
      indexBases.push_back({0});
      continue;
    }
    indexBases.push_back({1 << nextSrcBit});
    ++nextSrcBit;
  }
  auto newLayout = LinearLayout({{kReg, std::move(indexBases)}}, {kReg});

  const int numRegs = newLayout.getInDimSize(kReg);
  SmallVector<Value> expanded;
  expanded.reserve(numRegs);
  for (int reg = 0; reg < numRegs; ++reg) {
    int32_t srcIdx = newLayout.apply({{kReg, reg}}).begin()->second;
    expanded.push_back(values[srcIdx]);
  }
  return expanded;
}
// Merges the two name lists `x` and `y` into a single list in which every
// element of either list appears exactly once.
//
// The lists are consumed front to back with one cursor each. At each step one
// of the two head elements is appended to the result, chosen from the position
// that element has in the other list. If the bookkeeping finds that a chosen
// element lies before the other list's cursor, the function aborts with
// llvm_unreachable("Supremum does not exist").
// NOTE(review): judging by the name and that message, the result is meant to
// be an ordering compatible with both inputs - confirm against callers.
SmallVector<StringAttr> supremum(const SmallVector<StringAttr> &x,
const SmallVector<StringAttr> &y) {
// Insertion-ordered set: keeps emission order and rejects duplicates.
llvm::SetVector<StringAttr> result;
// Index of each element within x and within y. If an element is repeated in
// a list, the later index overwrites the earlier one.
DenseMap<StringAttr, int> posX, posY;
for (auto [idx, elem] : llvm::enumerate(x))
posX[elem] = idx;
for (auto [idx, elem] : llvm::enumerate(y))
posY[elem] = idx;
// i and j are the cursors into x and y.
int i = 0, j = 0;
// Sentinel meaning "the head has no position at or after the other cursor".
const int INF = std::numeric_limits<int>::max();
while (i < x.size() || j < y.size()) {
// Skip heads that were already emitted.
while (i < x.size() && result.contains(x[i]))
++i;
while (j < y.size() && result.contains(y[j]))
++j;
if (i >= x.size() && j >= y.size())
break;
// Both heads are the same element: emit it once and advance both cursors.
if (i < x.size() && j < y.size() && x[i] == y[j]) {
if (posY[x[i]] < j)
llvm_unreachable("Supremum does not exist");
result.insert(x[i]);
++i, ++j;
continue;
}
// candX: position of x's head in y, if it occurs at or after y's cursor.
// candY: position of y's head in x, if it occurs at or after x's cursor.
int candX = INF, candY = INF;
if (i < x.size()) {
if (posY.count(x[i]) && posY[x[i]] >= j)
candX = posY[x[i]];
}
if (j < y.size()) {
if (posX.count(y[j]) && posX[y[j]] >= i)
candY = posX[y[j]];
}
// x's head has no such position in y: emit it right away. This case is
// tested before the symmetric one for y.
if (i < x.size() && candX == INF) {
result.insert(x[i]);
++i;
continue;
}
// Same for y's head with respect to x.
if (j < y.size() && candY == INF) {
result.insert(y[j]);
++j;
continue;
}
// Otherwise emit the head with the smaller position in the other list;
// on a tie x's head is taken.
if (candX <= candY) {
if (posY[x[i]] < j)
llvm_unreachable("Supremum does not exist");
result.insert(x[i]);
++i;
} else {
if (posX[y[j]] < i)
llvm_unreachable("Supremum does not exist");
result.insert(y[j]);
++j;
}
}
return to_vector(result);
}
/// Reshapes the output dimensions of `layout` to the standard dimensions
/// ("dim0", "dim1", ...) with the sizes given by `shape`.
///
/// The reshape itself is done with both the source output dimensions and the
/// new dimensions in reversed order; the result is then transposed back to
/// the standard dimension order.
LinearLayout reshapeLayout(MLIRContext *ctx, LinearLayout layout,
                           ArrayRef<int64_t> shape) {
  int rank = shape.size();

  auto reversedSrcDims = to_vector(layout.getOutDimNames());
  std::reverse(reversedSrcDims.begin(), reversedSrcDims.end());

  auto reversedDstDims = standardOutDimPairs(ctx, shape);
  std::reverse(reversedDstDims.begin(), reversedDstDims.end());

  LinearLayout srcReversed = layout.transposeOuts(reversedSrcDims);
  LinearLayout reshaped = srcReversed.reshapeOuts(reversedDstDims);
  return reshaped.transposeOuts(standardOutDimNames(ctx, rank));
}
/// Returns a layout whose basis components are those of `layout` permuted by
/// `order`: component k of each new basis is component `order[k]` of the old
/// one. The output dimension names are taken from `layout` in their existing
/// order.
LinearLayout transposeLinearLayout(LinearLayout layout, ArrayRef<int> order) {
  auto permutedBases = layout.getBases();
  for (auto &inDimBases : llvm::make_second_range(permutedBases)) {
    for (auto &basis : inDimBases) {
      std::vector<int32_t> reordered;
      reordered.reserve(order.size());
      for (int src : order)
        reordered.push_back(basis[src]);
      basis = std::move(reordered);
    }
  }
  return LinearLayout(std::move(permutedBases),
                      to_vector(layout.getOutDimNames()));
}
}