#include "triton/Tools/LinearLayout.h"

#include <algorithm>
#include <cstdint>
#include <memory>
#include <numeric>
#include <optional>
#include <set>
#include <string>
#include <vector>

#include "mlir/IR/BuiltinAttributes.h"
#include "third_party/f2reduce/f2reduce.h"
#include "triton/Tools/LayoutUtils.h"
#include "triton/Tools/StrUtil.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "linear_layout"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
#if defined(_MSC_VER) && !defined(__clang__)
#include <intrin.h>
// MSVC (without clang) lacks the GCC-style count-trailing-zeros builtins used
// below, so provide equivalents on top of the _BitScanForward intrinsics.
// As with the GCC builtins, the result is unspecified for x == 0; every call
// site in this file only passes non-zero values.
static int __builtin_ctz(unsigned x) {
  unsigned long r;
  _BitScanForward(&r, x);
  return static_cast<int>(r);
}
static int __builtin_ctzll(unsigned long long x) {
  unsigned long r;
#if defined(_M_X64) || defined(_M_ARM64)
  _BitScanForward64(&r, x);
#else
  // _BitScanForward64 only exists on 64-bit targets; scan the low and high
  // 32-bit halves separately instead.
  if (_BitScanForward(&r, static_cast<unsigned long>(x & 0xFFFFFFFFull)))
    return static_cast<int>(r);
  _BitScanForward(&r, static_cast<unsigned long>(x >> 32));
  r += 32;
#endif
  return static_cast<int>(r);
}
#endif
namespace mlir::triton {
namespace {
using BasesT = LinearLayout::BasesT;
using llvm::SmallDenseSet;
using llvm::Twine;
// Converts an ordered list of (in-dim, bases) pairs into a BasesT map.
// Insertion order follows the input; a repeated in-dim keeps the last entry.
BasesT makeBasesMap(
    ArrayRef<std::pair<StringAttr, std::vector<std::vector<int32_t>>>> bases) {
  BasesT result;
  for (const auto &entry : bases)
    result[entry.first] = entry.second;
  return result;
}
// Debug helper: prints each row of an F2 matrix (one uint64_t per row) to
// stderr as "0b" followed by its bits, column 0 first.
// `m` must point to at least `numRows` rows; numCols must be <= 64.
void dumpMatrix(const uint64_t *m, int numRows, int numCols) {
  assert(numCols <= 64);
  for (int r = 0; r < numRows; r++) {
    llvm::errs() << "0b";
    for (int c = 0; c < numCols; c++) {
      // Test the bit on the 64-bit value itself: `1 << c` is an int shift and
      // is undefined / wrong for c >= 31.
      llvm::errs() << (((m[r] >> c) & 1) != 0 ? "1" : "0");
    }
    llvm::errs() << "\n";
  }
}
// Builds the F2 matrix of `layout`: one uint64_t per output bit (rows ordered
// by out-dim, low bit first) and one bit column per input bit (columns ordered
// by in-dim, low bit first). Bit `col` of row `row` is set when input bit `col`
// flips output bit `row`.
std::unique_ptr<uint64_t[]> getMatrix(const LinearLayout &layout) {
  const int numRows = layout.getTotalOutDimSizeLog2();
  const int numCols = layout.getTotalInDimSizeLog2();
  assert(numCols <= 64 && "LinearLayout too large");
  assert(numRows <= 64 && "LinearLayout too large");
  std::unique_ptr<uint64_t[]> m(new uint64_t[numRows]());
  int rowBase = 0;
  for (StringAttr outDim : layout.getOutDimNames()) {
    const int outBits = layout.getOutDimSizeLog2(outDim);
    int col = 0;
    for (StringAttr inDim : layout.getInDimNames()) {
      const int inBits = layout.getInDimSizeLog2(inDim);
      for (int i = 0; i < inBits; ++i, ++col) {
        const uint64_t basis = layout.getBasis(inDim, i, outDim);
        for (int bit = 0; bit < outBits; ++bit)
          m[rowBase + bit] |= ((basis >> bit) & 1) << col;
      }
    }
    rowBase += outBits;
  }
  return m;
}
// Returns the rank over F2 of a matrix stored one uint64_t per row. The
// matrix is row-reduced in place via f2reduce; the rank is then the number of
// rows that remain non-zero.
int getMatrixRank(std::unique_ptr<uint64_t[]> m, int numRows, int numCols) {
  assert(numCols <= 64);
  f2reduce::inplace_rref_strided(m.get(), numRows, numCols, 1);
  int nonZeroRows = 0;
  for (const uint64_t *row = m.get(); row != m.get() + numRows; ++row) {
    if (*row != 0)
      ++nonZeroRows;
  }
  return nonZeroRows;
}
template <typename T, typename U>
void assertDimsEqualIgnoringOrder(T &&a, U &&b) {
SmallDenseSet<StringAttr> as(a.begin(), a.end());
SmallDenseSet<StringAttr> bs(b.begin(), b.end());
if (as != bs) {
llvm::report_fatal_error("Dimensions must match, ignoring order, but they "
"don't. Got dims: [" +
Twine(triton::join(a, ", ")) + "] and [" +
triton::join(b, ", ") + "]");
}
}
template <typename T, typename U>
void assertDimsSubsetIgnoringOrder(T &&small, U &&big) {
SmallDenseSet<StringAttr> smallSet(small.begin(), small.end());
SmallDenseSet<StringAttr> bigSet(big.begin(), big.end());
if (!llvm::set_is_subset(smallSet, bigSet)) {
llvm::report_fatal_error("Dimensions must be a subset, ignoring order, but "
"they aren't. Got dims: [" +
Twine(triton::join(small, ", ")) + "] and [" +
triton::join(big, ", ") + "]");
}
}
}
// Builds a layout without aborting: returns std::nullopt if checkInvariants
// reports any problem (the error text is discarded).
std::optional<LinearLayout>
LinearLayout::tryCreate(BasesT bases,
                        ArrayRef<std::pair<StringAttr, int32_t>> outDims,
                        bool requireSurjective) {
  LinearLayout candidate(std::move(bases), outDims, NoCheckInvariants{});
  if (candidate.checkInvariants(requireSurjective).has_value())
    return std::nullopt;
  return candidate;
}
// Raw constructor: stores the bases and out-dim sizes as given, in order,
// without validating them.
LinearLayout::LinearLayout(BasesT bases,
                           ArrayRef<std::pair<StringAttr, int32_t>> outDims,
                           NoCheckInvariants)
    : bases(std::move(bases)) {
  for (const auto &dimAndSize : outDims)
    this->outDims[dimAndSize.first] = dimAndSize.second;
}
// Builds a layout whose out-dim sizes are inferred from the bases: each
// out-dim gets the smallest power of two strictly greater than every basis
// value seen for it (at least 1). The result must be surjective; any invariant
// violation is a fatal error.
LinearLayout::LinearLayout(BasesT bases, ArrayRef<StringAttr> outDimNames)
    : bases(std::move(bases)) {
  for (StringAttr outDim : outDimNames) {
    outDims[outDim] = 1;
  }
  for (const auto &[inDim, inDimBases] : this->bases) {
    for (const auto &basis : inDimBases) {
      // A basis with more entries than there are out-dims is malformed and is
      // reported by checkInvariants() below; only read the entries that have a
      // matching out-dim so we never index past the end of outDimNames.
      size_t numEntries = std::min(basis.size(), outDimNames.size());
      for (size_t i = 0; i < numEntries; i++) {
        int32_t &size = outDims[outDimNames[i]];
        size = std::max<int32_t>(size, llvm::NextPowerOf2(basis[i]));
      }
    }
  }
  std::optional<std::string> error =
      checkInvariants(/*requireSurjective=*/true);
  if (error.has_value()) {
    llvm::report_fatal_error(StringRef(*error));
  }
}
// Checked constructor: stores bases and out-dims, then validates them; any
// invariant violation (including non-surjectivity when requested) is fatal.
LinearLayout::LinearLayout(BasesT bases,
                           ArrayRef<std::pair<StringAttr, int32_t>> outDims,
                           bool requireSurjective)
    : LinearLayout(std::move(bases), outDims, NoCheckInvariants{}) {
  if (std::optional<std::string> error = checkInvariants(requireSurjective))
    llvm::report_fatal_error(StringRef(*error));
}
// Validates the layout and caches its rank in `this->rank`.
// Returns an error description on the first violated invariant, or
// std::nullopt if the layout is well-formed (and surjective, when
// `requireSurjective` is set).
std::optional<std::string>
LinearLayout::checkInvariants(bool requireSurjective) {
  LDBG("checkInvariants: " << toString());
  // Every basis value must be non-negative.
  for (const auto &[inDim, inDimBases] : bases) {
    for (const auto &basis : inDimBases) {
      if (llvm::any_of(basis, [](int32_t b) { return b < 0; })) {
        return "Invalid bases passed to LinearLayout. Expected all basis "
               "values to be non-negative, but found a negative value for "
               "in dimension '" +
               inDim.str() + "'. Full list of bases:" + toString() + "\n";
      }
    }
  }
  // Every basis must have exactly one entry per out-dim.
  for (const auto &[inDim, inDimBases] : bases) {
    for (const auto &basis : inDimBases) {
      if (basis.size() != outDims.size()) {
        return "Invalid bases passed to LinearLayout. Expect all bases to "
               "have the same size, equal to outDimNames.size() (" +
               std::to_string(outDims.size()) +
               "). But this failed for in dimension '" + inDim.str() +
               "'. Full list of bases:" + toString() + "\n";
      }
    }
  }
  // Out-dim sizes must be positive powers of two. Negative sizes are rejected
  // explicitly: INT32_MIN converted to uint32_t is 2^31, which
  // isPowerOf2_32 would otherwise accept.
  for (const auto &[outDim, size] : outDims) {
    if (size <= 0 || !llvm::isPowerOf2_32(size)) {
      return "Invalid out-dim size " + std::to_string(size) + " for out-dim '" +
             outDim.str() + "'. Out-dim sizes must be powers of 2.\n";
    }
  }
  // Each basis entry must be smaller than its out-dim's size.
  SmallVector<StringAttr> outDimNames = llvm::to_vector(getOutDimNames());
  for (const auto &[inDim, inDimBases] : this->bases) {
    for (const auto &basis : inDimBases) {
      for (size_t i = 0; i < basis.size(); i++) {
        if (basis[i] >= outDims[outDimNames[i]]) {
          return "Invalid basis " + std::to_string(basis[i]) +
                 " for in-dim '" + inDim.str() + "' and out-dim '" +
                 outDimNames[i].str() +
                 "'. Basis must be less than the out-dim size.\n";
        }
      }
    }
  }
  // Cache the F2 rank of the layout's matrix (rows = out bits, columns = in
  // bits). NOTE(review): isSurjective() is presumably derived from this cached
  // rank — confirm in the header.
  this->rank =
      getMatrixRank(getMatrix(*this), /*numRows=*/getTotalOutDimSizeLog2(),
                    /*numCols=*/getTotalInDimSizeLog2());
  if (requireSurjective && !isSurjective()) {
    return "Layout is expected to be surjective, i.e. every `out` coordinate "
           "can be reached by some `in` coordinate, but was not:" +
           toString();
  }
  return std::nullopt;
}
// Convenience overload taking an ordered list of (in-dim, bases) pairs; the
// out-dim sizes are inferred from the bases by the BasesT/outDimNames
// constructor.
LinearLayout::LinearLayout(
    ArrayRef<std::pair<StringAttr, std::vector<std::vector<int32_t>>>> bases,
    ArrayRef<StringAttr> outDimNames)
    : LinearLayout(makeBasesMap(bases), outDimNames) {}
// Convenience overload taking an ordered list of (in-dim, bases) pairs and
// explicit out-dim sizes; validated by the checked BasesT constructor.
LinearLayout::LinearLayout(
    ArrayRef<std::pair<StringAttr, std::vector<std::vector<int32_t>>>> bases,
    ArrayRef<std::pair<StringAttr, int32_t>> outDims, bool requireSurjective)
    : LinearLayout(makeBasesMap(bases), outDims, requireSurjective) {}
// 1D layout mapping input index x to output x * stride. The out-dim has size
// stride * size. Surjectivity is required only for a unit stride, where the
// bases 1, 2, 4, ... reach every output coordinate. A size of 0 yields the
// empty layout.
LinearLayout LinearLayout::strided1D(int32_t size, int32_t stride,
                                     StringAttr inDimName,
                                     StringAttr outDimName) {
  if (size == 0)
    return LinearLayout::empty();
  assert(llvm::isPowerOf2_32(size));
  std::vector<std::vector<int32_t>> bases;
  for (int32_t step = 1; step < size; step *= 2)
    bases.emplace_back(1, step * stride);
  const bool requiresSurjective = stride == 1;
  return LinearLayout({{inDimName, std::move(bases)}},
                      {{outDimName, stride * size}}, requiresSurjective);
}
// 1D layout sending every one of `size` inputs to output 0 of an out-dim of
// size `outDimSize`. It is surjective only when outDimSize == 1. A size of 0
// yields the empty layout.
LinearLayout LinearLayout::zeros1D(int32_t size,
                                   StringAttr inDimName,
                                   StringAttr outDimName,
                                   int32_t outDimSize) {
  if (size == 0)
    return LinearLayout::empty();
  assert(llvm::isPowerOf2_32(size));
  std::vector<std::vector<int32_t>> zeros;
  for (int bit = 1; bit < size; bit *= 2)
    zeros.push_back(std::vector<int32_t>(1, 0));
  return LinearLayout({{inDimName, zeros}}, {{outDimName, outDimSize}},
                      outDimSize == 1);
}
// Returns the position of `outDim` among this layout's out-dims; aborts with
// a fatal error if the layout has no such out-dim.
int32_t LinearLayout::getOutDimIndex(StringAttr outDim) const {
  int32_t idx = 0;
  for (StringAttr name : getOutDimNames()) {
    if (name == outDim)
      return idx;
    ++idx;
  }
  llvm::report_fatal_error("outDim " + Twine(outDim) + " is not in layout" +
                           toString());
}
// log2 of an in-dim's size, i.e. its number of bases. `inDim` must exist.
int32_t LinearLayout::getInDimSizeLog2(StringAttr inDim) const {
  const auto found = bases.find(inDim);
  assert(found != bases.end());
  return found->second.size();
}
int32_t LinearLayout::getTotalInDimSizeLog2() const {
return std::accumulate(getInDimNames().begin(), getInDimNames().end(), 0,
[&](int32_t acc, StringAttr inDim) {
return acc + getInDimSizeLog2(inDim);
});
}
// log2 of an out-dim's size (floor, exact for the power-of-two sizes that
// checkInvariants requires). `outDim` must exist.
int32_t LinearLayout::getOutDimSizeLog2(StringAttr outDim) const {
  const auto found = outDims.find(outDim);
  assert(found != outDims.end());
  return llvm::Log2_32(found->second);
}
int32_t LinearLayout::getTotalOutDimSizeLog2() const {
return std::accumulate(getOutDimNames().begin(), getOutDimNames().end(), 0,
[&](int32_t acc, StringAttr outDim) {
return acc + getOutDimSizeLog2(outDim);
});
}
// Returns how many consecutive values of the first in-dim map to consecutive
// values of the first out-dim (a power of two, at least 1).
int32_t LinearLayout::getNumConsecutiveInOut() const {
  if (bases.empty() || getNumOutDims() == 0)
    return 1;
  const auto &[firstInDim, firstInDimBases] = *bases.begin();
  // True when `basis` is exactly (2^k, 0, ..., 0).
  auto isUnitStep = [](const std::vector<int32_t> &basis, int k) {
    return basis[0] == (1 << k) &&
           std::all_of(basis.begin() + 1, basis.end(),
                       [](int32_t x) { return x == 0; });
  };
  int consec = 0;
  while (consec < static_cast<int>(firstInDimBases.size()) &&
         isUnitStep(firstInDimBases[consec], consec))
    ++consec;
  // OR together the first-out-dim component of every basis not counted above;
  // its lowest set bit caps how far the contiguous run can extend.
  int32_t otherBits = 0;
  for (const auto &[inDim, inDimBases] : bases) {
    const int start = (inDim == firstInDim) ? consec : 0;
    for (int i = start; i < static_cast<int>(inDimBases.size()); ++i)
      otherBits |= inDimBases[i][0];
  }
  const int32_t trailingZeros =
      otherBits == 0 ? 31 : __builtin_ctz(otherBits);
  return 1 << std::min(consec, trailingZeros);
}
// Reorders the in-dims to `newInDims` (which must be a permutation of the
// current in-dims). Bases and out-dims are unchanged.
LinearLayout LinearLayout::transposeIns(ArrayRef<StringAttr> newInDims) const {
  assertDimsEqualIgnoringOrder(newInDims, getInDimNames());
  BasesT reordered;
  for (StringAttr inDim : newInDims)
    reordered[inDim] = bases.find(inDim)->second;
  return LinearLayout(std::move(reordered), llvm::to_vector(outDims),
                      isSurjective());
}
// Reorders the out-dims to `newOutDims` (a permutation of the current
// out-dims), permuting every basis vector's entries to match.
LinearLayout
LinearLayout::transposeOuts(ArrayRef<StringAttr> newOutDims) const {
  assertDimsEqualIgnoringOrder(newOutDims, getOutDimNames());
  // oldIndexOf[k] is the current position of the k-th new out-dim.
  SmallVector<int32_t> oldIndexOf;
  oldIndexOf.reserve(newOutDims.size());
  for (StringAttr outDim : newOutDims)
    oldIndexOf.push_back(getOutDimIndex(outDim));
  BasesT permuted;
  for (const auto &[inDim, inDimBases] : bases) {
    std::vector<std::vector<int32_t>> &dst = permuted[inDim];
    dst.reserve(inDimBases.size());
    for (const std::vector<int32_t> &basis : inDimBases) {
      std::vector<int32_t> shuffled;
      shuffled.reserve(oldIndexOf.size());
      for (int32_t oldIdx : oldIndexOf)
        shuffled.push_back(basis[oldIdx]);
      dst.push_back(std::move(shuffled));
    }
  }
  SmallVector<std::pair<StringAttr, int32_t>> newOutDimSizes;
  for (StringAttr outDim : newOutDims)
    newOutDimSizes.push_back({outDim, getOutDimSize(outDim)});
  return LinearLayout(std::move(permuted), newOutDimSizes, isSurjective());
}
// Re-splits the input space into `newInDims` (power-of-two sizes whose product
// equals the current total in size). All bases are concatenated in in-dim
// order and dealt out to the new in-dims, log2(size) bases each.
LinearLayout LinearLayout::reshapeIns(
    ArrayRef<std::pair<StringAttr, int32_t>> newInDims) const {
  assert(llvm::all_of(newInDims, [&](auto &inDim) {
    return llvm::isPowerOf2_32(inDim.second);
  }));
  assert(getTotalInDimSize() == std::accumulate(newInDims.begin(),
                                                newInDims.end(), 1,
                                                [&](int32_t acc, auto &inDim) {
                                                  return acc * inDim.second;
                                                }));
  SmallVector<std::vector<int32_t>> flatBases;
  for (const auto &entry : bases)
    flatBases.append(entry.second.begin(), entry.second.end());
  BasesT newBases;
  auto nextBasis = flatBases.begin();
  for (const auto &[inDim, inDimSize] : newInDims) {
    auto &dst = newBases[inDim];
    for (int32_t step = 1; step < inDimSize; step *= 2)
      dst.push_back(*nextBasis++);
  }
  return LinearLayout(std::move(newBases), llvm::to_vector(outDims),
                      isSurjective());
}
// Re-splits the output space into `newOutDims` (power-of-two sizes whose
// product equals the current total out size). Each basis is packed into a
// single flat output index (first out-dim in the low bits) and then unpacked
// into the new out-dims, first new out-dim in the low bits.
LinearLayout LinearLayout::reshapeOuts(
    ArrayRef<std::pair<StringAttr, int32_t>> newOutDims) const {
  assert(llvm::all_of(newOutDims, [&](auto &outDim) {
    return llvm::isPowerOf2_32(outDim.second);
  }));
  assert(getTotalOutDimSize() ==
         std::accumulate(
             newOutDims.begin(), newOutDims.end(), 1,
             [&](int32_t acc, auto &outDim) { return acc * outDim.second; }));
  // shifts[k] is the bit offset of the k-th current out-dim in the flat index.
  SmallVector<int32_t> shifts;
  shifts.push_back(0);
  for (StringAttr outDim : getOutDimNames())
    shifts.push_back(shifts.back() + getOutDimSizeLog2(outDim));
  BasesT newBases;
  for (const auto &[inDim, inDimBases] : bases) {
    std::vector<std::vector<int32_t>> &dst = newBases[inDim];
    for (const auto &basis : inDimBases) {
      int packed = 0;
      for (size_t k = 0; k < basis.size(); ++k)
        packed += basis[k] << shifts[k];
      std::vector<int32_t> unpacked;
      for (int32_t newSize : llvm::make_second_range(newOutDims)) {
        unpacked.push_back(packed % newSize);
        packed /= newSize;
      }
      dst.push_back(std::move(unpacked));
    }
  }
  return LinearLayout(std::move(newBases), newOutDims, isSurjective());
}
// Concatenates two layouts along their input dimensions. Both layouts must
// have identical out-dims (names, order and sizes) and disjoint in-dims; the
// result has this layout's in-dims followed by `other`'s.
LinearLayout LinearLayout::concatIns(const LinearLayout &other) const {
  assert(llvm::to_vector(getOutDimNames()) ==
             llvm::to_vector(other.getOutDimNames()) &&
         "layouts must have the same output dimensions");
  for (StringAttr outDim : getOutDimNames()) {
    assert(getOutDimSize(outDim) == other.getOutDimSize(outDim) &&
           "layouts must have the same output dimension sizes");
  }
  LinearLayout::BasesT resultBases = getBases();
  for (const auto &otherInDim : other.getBases()) {
    // MapVector::insert keeps the existing entry on a key collision, which
    // would silently drop `other`'s bases for a shared in-dim.
    [[maybe_unused]] bool inserted = resultBases.insert(otherInDim).second;
    assert(inserted && "layouts must have disjoint input dimensions");
  }
  SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
  for (auto &[outDim, outDimSize] : outDims)
    newOutDims.emplace_back(outDim, outDimSize);
  return LinearLayout(std::move(resultBases), newOutDims,
                      /*requireSurjective=*/false);
}
// Concatenates two layouts along their output dimensions. Both layouts must
// have identical in-dims (names, order and sizes) and disjoint out-dims; each
// result basis is this layout's basis followed by `other`'s.
LinearLayout LinearLayout::concatOuts(const LinearLayout &other) const {
  assert(llvm::to_vector(getInDimNames()) ==
             llvm::to_vector(other.getInDimNames()) &&
         "layouts must have the same input dimensions");
  for (StringAttr inDim : getInDimNames()) {
    assert(getInDimSize(inDim) == other.getInDimSize(inDim) &&
           "layouts must have the same input dimension sizes");
  }
  // A shared out-dim would collapse into a single out-dims entry while each
  // basis keeps both components, which checkInvariants only reports as a
  // confusing basis-size mismatch; diagnose it directly instead.
  for (StringAttr outDim : other.getOutDimNames()) {
    assert(!hasOutDim(outDim) &&
           "layouts must have disjoint output dimensions");
  }
  LinearLayout::BasesT result;
  for (auto [lhsBases, rhsBases] : llvm::zip(getBases(), other.getBases())) {
    auto &resultBases = result[lhsBases.first];
    assert(lhsBases.first == rhsBases.first);
    for (auto [lhsBasis, rhsBasis] :
         llvm::zip(lhsBases.second, rhsBases.second)) {
      std::vector<int32_t> resultBasis;
      llvm::append_range(resultBasis, lhsBasis);
      llvm::append_range(resultBasis, rhsBasis);
      resultBases.push_back(std::move(resultBasis));
    }
  }
  SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
  for (auto &[outDim, outDimSize] : outDims)
    newOutDims.emplace_back(outDim, outDimSize);
  for (auto &[outDim, outDimSize] : other.outDims)
    newOutDims.emplace_back(outDim, outDimSize);
  return LinearLayout(std::move(result), newOutDims,
                      /*requireSurjective=*/false);
}
// Left division: finds C such that A == B * C, where B is the "inner" factor
// (in operator*, the inner layout supplies the low in-bits and the low
// out-bits of every shared dim) and C is the "outer" factor.
// Returns std::nullopt when no such C exists.
std::optional<LinearLayout> divideLeft(const LinearLayout &A,
const LinearLayout &B) {
// B's in-dims and out-dims must all be dims of A.
for (StringAttr dim : B.getInDimNames()) {
if (!llvm::is_contained(A.getInDimNames(), dim))
return std::nullopt;
}
for (StringAttr dim : B.getOutDimNames()) {
if (!llvm::is_contained(A.getOutDimNames(), dim))
return std::nullopt;
}
// C has all of A's out-dims; each is A's size divided by B's (a dim missing
// from B counts as size 1).
llvm::MapVector<StringAttr, int32_t> cOutDimSizes;
for (StringAttr outDim : A.getOutDimNames()) {
int outA = A.getOutDimSizeLog2(outDim);
int outB = B.hasOutDim(outDim) ? B.getOutDimSizeLog2(outDim) : 0;
int outC = outA - outB;
if (outC < 0)
return std::nullopt;
cOutDimSizes[outDim] = 1 << outC;
}
LinearLayout::BasesT cBases;
for (StringAttr inDim : A.getInDimNames()) {
int inA = A.getInDimSizeLog2(inDim);
int inB = B.hasInDim(inDim) ? B.getInDimSizeLog2(inDim) : 0;
int inC = inA - inB;
if (inC < 0)
return std::nullopt;
std::vector<std::vector<int32_t>> basesForDim;
// The first inB bases of this in-dim must be B's bases verbatim, with 0 in
// every out-dim that B does not have.
for (int i = 0; i < inB; ++i) {
for (StringAttr outDim : A.getOutDimNames()) {
int expected = B.hasOutDim(outDim) ? B.getBasis(inDim, i, outDim) : 0;
int actual = A.getBasis(inDim, i, outDim);
if (actual != expected)
return std::nullopt;
}
}
// The remaining bases are C's, shifted left by B's out-dim size: their low
// outB bits must be zero, and shifting right by outB recovers C's basis.
for (int i = inB; i < inA; ++i) {
std::vector<int32_t> candidateBasis;
for (StringAttr outDim : llvm::make_first_range(cOutDimSizes)) {
int outB = B.hasOutDim(outDim) ? B.getOutDimSizeLog2(outDim) : 0;
int v = A.getBasis(inDim, i, outDim);
if ((v & ((1 << outB) - 1)) != 0)
return std::nullopt;
candidateBasis.push_back(v >> outB);
}
basesForDim.push_back(std::move(candidateBasis));
}
cBases[inDim] = basesForDim;
}
SmallVector<std::pair<StringAttr, int32_t>> COutDims;
for (auto [outDim, outC] : cOutDimSizes) {
COutDims.push_back({outDim, outC});
}
// Surjectivity of C is only enforced when both A and B are surjective.
LinearLayout C(std::move(cBases), COutDims,
A.isSurjective() && B.isSurjective());
// Debug-build sanity check that the quotient reproduces A.
assert(B * C == A);
return C;
}
// Right division: finds C such that A == C * B, where C is the "inner" factor
// (its bases are the low in-bits and it owns the low out-bits) and B is the
// "outer" factor (its bases follow C's and are shifted above C's out-bits).
// Returns std::nullopt when no such C exists.
std::optional<LinearLayout> divideRight(const LinearLayout &A,
                                        const LinearLayout &B) {
  // B's in-dims and out-dims must all be dims of A.
  for (StringAttr dim : B.getInDimNames()) {
    if (!llvm::is_contained(A.getInDimNames(), dim))
      return std::nullopt;
  }
  for (StringAttr dim : B.getOutDimNames()) {
    if (!llvm::is_contained(A.getOutDimNames(), dim))
      return std::nullopt;
  }
  // C has all of A's out-dims; each is A's size divided by B's (a dim missing
  // from B counts as size 1).
  llvm::MapVector<StringAttr, int32_t> cOutDimSizes;
  for (StringAttr outDim : A.getOutDimNames()) {
    int outA = A.getOutDimSizeLog2(outDim);
    int outB = B.hasOutDim(outDim) ? B.getOutDimSizeLog2(outDim) : 0;
    int outC = outA - outB;
    if (outC < 0)
      return std::nullopt;
    cOutDimSizes[outDim] = 1 << outC;
  }
  LinearLayout::BasesT cBases;
  for (StringAttr inDim : A.getInDimNames()) {
    int inA = A.getInDimSizeLog2(inDim);
    int inB = B.hasInDim(inDim) ? B.getInDimSizeLog2(inDim) : 0;
    int inC = inA - inB;
    if (inC < 0)
      return std::nullopt;
    std::vector<std::vector<int32_t>> basesForDim;
    // The first inC bases of A are C's bases verbatim, so every entry must
    // fit in C's (smaller) out-dim. Otherwise no C exists; returning here also
    // avoids a fatal invariant failure when C is constructed below.
    for (int i = 0; i < inC; ++i) {
      std::vector<int32_t> candidate;
      for (const auto &[outDim, outSizeC] : cOutDimSizes) {
        int v = A.getBasis(inDim, i, outDim);
        if (v >= outSizeC)
          return std::nullopt;
        candidate.push_back(v);
      }
      basesForDim.push_back(std::move(candidate));
    }
    // The remaining inB bases must be B's bases shifted left by C's out-dim
    // size. In out-dims that B lacks, C * B contributes nothing, so A's
    // entries there must be zero.
    for (int i = inC; i < inA; ++i) {
      int j = i - inC; // Index into B's bases for this in-dim.
      for (StringAttr outDim : A.getOutDimNames()) {
        int v = A.getBasis(inDim, i, outDim);
        if (!B.hasOutDim(outDim)) {
          if (v != 0)
            return std::nullopt;
          continue;
        }
        int shift = A.getOutDimSizeLog2(outDim) - B.getOutDimSizeLog2(outDim);
        if ((v & ((1 << shift) - 1)) != 0)
          return std::nullopt;
        if ((v >> shift) != B.getBasis(inDim, j, outDim))
          return std::nullopt;
      }
    }
    cBases[inDim] = std::move(basesForDim);
  }
  SmallVector<std::pair<StringAttr, int32_t>> COutDims;
  for (auto [outDim, size] : cOutDimSizes)
    COutDims.push_back({outDim, size});
  // Surjectivity of C is only enforced when both A and B are surjective.
  LinearLayout C(std::move(cBases), COutDims,
                 A.isSurjective() && B.isSurjective());
  assert(C * B == A);
  return C;
}
// Product of two layouts. For each dim shared by both operands, `inner`
// supplies the low in-bits and low out-bits. `outer`'s bases follow inner's
// in-bits and are shifted above inner's out-bits. The result is required to be
// surjective only if both operands are.
LinearLayout operator*(LinearLayout inner, LinearLayout outer) {
  // supremum (from LayoutUtils) merges the two operands' dim-name lists.
  auto inDims = supremum(llvm::to_vector(inner.getInDimNames()),
                         llvm::to_vector(outer.getInDimNames()));
  auto outDims = supremum(llvm::to_vector(inner.getOutDimNames()),
                          llvm::to_vector(outer.getOutDimNames()));
  llvm::MapVector<StringAttr, int32_t> inDimSizesLog2;
  llvm::MapVector<StringAttr, int32_t> outDimSizesLog2;
  for (const auto &dim : inDims)
    inDimSizesLog2.insert({dim, 0});
  for (const auto &dim : outDims)
    outDimSizesLog2.insert({dim, 0});
  // log2 sizes add up across the two operands. Iterate over pointers: the
  // braced list `{inner, outer}` would copy both layouts into an
  // std::initializer_list.
  for (const LinearLayout *layout : {&inner, &outer}) {
    for (StringAttr inDim : layout->getInDimNames())
      inDimSizesLog2[inDim] += layout->getInDimSizeLog2(inDim);
    for (StringAttr outDim : layout->getOutDimNames())
      outDimSizesLog2[outDim] += layout->getOutDimSizeLog2(outDim);
  }
  BasesT allBases;
  for (auto [inDimName, inDimSizeLog2] : inDimSizesLog2) {
    std::vector<std::vector<int32_t>> &inDimBases = allBases[inDimName];
    inDimBases = std::vector<std::vector<int32_t>>(
        inDimSizeLog2, std::vector<int32_t>(outDimSizesLog2.size(), 0));
    for (auto [outDimIdx, outDimNameAndSize] :
         llvm::enumerate(outDimSizesLog2)) {
      StringAttr outDimName = outDimNameAndSize.first;
      // Inner's bases fill the low in-bits, unshifted.
      if (inner.hasInDim(inDimName) && inner.hasOutDim(outDimName)) {
        for (int i = 0; i < inner.getInDimSizeLog2(inDimName); i++) {
          inDimBases[i][outDimIdx] = inner.getBasis(inDimName, i, outDimName);
        }
      }
      // Outer's bases come after inner's in-bits, shifted above inner's
      // out-bits.
      if (outer.hasInDim(inDimName) && outer.hasOutDim(outDimName)) {
        int offset =
            inner.hasInDim(inDimName) ? inner.getInDimSizeLog2(inDimName) : 0;
        int shift = inner.hasOutDim(outDimName)
                        ? inner.getOutDimSizeLog2(outDimName)
                        : 0;
        for (int i = 0; i < outer.getInDimSizeLog2(inDimName); i++) {
          inDimBases[offset + i][outDimIdx] =
              outer.getBasis(inDimName, i, outDimName) << shift;
        }
      }
    }
  }
  llvm::SmallVector<std::pair<StringAttr, int32_t>> outDimSizes;
  for (auto [outDim, sizeLog2] : outDimSizesLog2) {
    outDimSizes.push_back({outDim, 1 << sizeLog2});
  }
  return LinearLayout(std::move(allBases), outDimSizes,
                      inner.isSurjective() && outer.isSurjective());
}
// True if the layout acts as the identity on `dimNames` and does not mix
// those dims with any other dims. Every name in `dimNames` must be an in-dim
// or an out-dim of this layout; otherwise the result is false.
bool LinearLayout::isTrivialOver(ArrayRef<StringAttr> dimNames) const {
  auto isKnownDim = [&](StringAttr dim) {
    return llvm::is_contained(getInDimNames(), dim) ||
           llvm::is_contained(getOutDimNames(), dim);
  };
  if (!llvm::all_of(dimNames, isKnownDim))
    return false;
  // Dims of `allDimNames` that are not listed in `dimNames`.
  auto complement = [&](auto allDimNames) {
    SmallVector<StringAttr> rest;
    for (StringAttr dim : allDimNames) {
      if (!llvm::is_contained(dimNames, dim))
        rest.push_back(dim);
    }
    return rest;
  };
  SmallVector<StringAttr> otherIns = complement(getInDimNames());
  SmallVector<StringAttr> otherOuts = complement(getOutDimNames());
  return squareSublayoutIsIdentity(*this, dimNames) &&
         sublayoutIsZero(otherIns, dimNames) &&
         sublayoutIsZero(dimNames, otherOuts);
}
// Removes `dimNames` from the layout when it is trivial over them (see
// isTrivialOver); returns std::nullopt otherwise.
std::optional<LinearLayout>
LinearLayout::quotient(ArrayRef<StringAttr> dimNames) const {
  if (!isTrivialOver(dimNames))
    return std::nullopt;
  SmallVector<StringAttr> keptIns;
  for (StringAttr dim : getInDimNames()) {
    if (!llvm::is_contained(dimNames, dim))
      keptIns.push_back(dim);
  }
  SmallVector<StringAttr> keptOuts;
  for (StringAttr dim : getOutDimNames()) {
    if (!llvm::is_contained(dimNames, dim))
      keptOuts.push_back(dim);
  }
  return sublayout(keptIns, keptOuts);
}
// Restricts the layout to the given in-dims and out-dims (each must be a
// subset of this layout's dims). The result keeps this layout's dim order, not
// the order of the arguments, and is not required to be surjective.
LinearLayout LinearLayout::sublayout(ArrayRef<StringAttr> inDimNames,
                                     ArrayRef<StringAttr> outDimNames) const {
  assertDimsSubsetIgnoringOrder(inDimNames, getInDimNames());
  assertDimsSubsetIgnoringOrder(outDimNames, getOutDimNames());
  SmallDenseSet<StringAttr> inDimSet(inDimNames.begin(), inDimNames.end());
  SmallDenseSet<StringAttr> outDimSet(outDimNames.begin(), outDimNames.end());
  // Positions (in this layout's out-dim order) of the out-dims to keep.
  SmallVector<int> outDimIndicesToKeep;
  for (auto [i, outDim] : llvm::enumerate(getOutDimNames())) {
    if (outDimSet.contains(outDim)) {
      outDimIndicesToKeep.push_back(i);
    }
  }
  BasesT newBases;
  // Bind by const reference so each in-dim's basis vectors are not copied.
  for (const auto &[inDim, inDimBases] : bases) {
    if (!inDimSet.contains(inDim)) {
      continue;
    }
    auto &newInDimBases = newBases[inDim];
    newInDimBases.reserve(inDimBases.size());
    for (const auto &basis : inDimBases) {
      auto &newBasis = newInDimBases.emplace_back();
      newBasis.reserve(outDimIndicesToKeep.size());
      for (int i : outDimIndicesToKeep) {
        newBasis.push_back(basis[i]);
      }
    }
  }
  SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
  for (const auto &[outDim, outDimSize] : outDims) {
    if (outDimSet.contains(outDim)) {
      newOutDims.push_back({outDim, outDimSize});
    }
  }
  return LinearLayout(std::move(newBases), newOutDims,
                      /*requireSurjective=*/false);
}
// True if every basis of the sublayout over (inDimNames, outDimNames) is all
// zeros, i.e. those in-dims never affect those out-dims.
bool LinearLayout::sublayoutIsZero(ArrayRef<StringAttr> inDimNames,
                                   ArrayRef<StringAttr> outDimNames) const {
  LinearLayout ss = sublayout(inDimNames, outDimNames);
  // Iterate by const reference; the previous by-value loops copied every
  // (in-dim, bases) entry and every basis vector.
  for (const auto &inDimAndBases : ss.bases) {
    for (const auto &basis : inDimAndBases.second) {
      if (!llvm::all_of(basis, [](int32_t b) { return b == 0; })) {
        return false;
      }
    }
  }
  return true;
}
// Evaluates the layout at one input point. `ins` must name exactly this
// layout's in-dims (any order). Each output coordinate is the XOR of the bases
// selected by the set bits of the input values; only the low
// getInDimSizeLog2(inDim) bits of each input value are read. Results are
// returned in this layout's out-dim order.
SmallVector<std::pair<StringAttr, int32_t>>
LinearLayout::apply(ArrayRef<std::pair<StringAttr, int32_t>> ins) const {
  assertDimsEqualIgnoringOrder(llvm::make_first_range(ins), getInDimNames());
  SmallVector<std::pair<StringAttr, int32_t>> result;
  for (StringAttr outDim : getOutDimNames()) {
    int32_t acc = 0;
    for (const auto &[inDim, val] : ins) {
      const int32_t numBits = getInDimSizeLog2(inDim);
      for (int bit = 0; bit < numBits; ++bit) {
        if (val & (1 << bit))
          acc ^= getBasis(inDim, bit, outDim);
      }
    }
    result.push_back({outDim, acc});
  }
  return result;
}
// Composes this layout with `outer` (outer ∘ this). This layout's out-dims
// must equal `outer`'s in-dims (any order), each no larger than the matching
// outer in-dim. The result is required to be surjective only when both are
// surjective and every intermediate dim size matches exactly.
LinearLayout LinearLayout::compose(const LinearLayout &outer) const {
  assertDimsEqualIgnoringOrder(getOutDimNames(), outer.getInDimNames());
  for (StringAttr outDim : getOutDimNames()) {
    assert(getOutDimSize(outDim) <= outer.getInDimSize(outDim));
  }
  BasesT composedBases;
  for (const auto &[inDim, inDimBases] : bases) {
    auto &composedInDimBases = composedBases[inDim];
    for (const auto &basis : inDimBases) {
      // Feed this basis vector to `outer` as an input point. These locals
      // were previously named `bases`/`newBases`, which shadowed the member
      // and the enclosing result map.
      SmallVector<std::pair<StringAttr, int32_t>> outerIns;
      for (auto [outDim, b] : llvm::zip(getOutDimNames(), basis)) {
        outerIns.push_back({outDim, b});
      }
      auto outerOuts = outer.apply(outerIns);
      auto outerOutVals = llvm::make_second_range(outerOuts);
      composedInDimBases.push_back(
          std::vector<int32_t>(outerOutVals.begin(), outerOutVals.end()));
    }
  }
  bool compositionIsSurjective =
      isSurjective() && outer.isSurjective() &&
      llvm::all_of(getOutDimNames(), [&](StringAttr outDim) {
        return getOutDimSize(outDim) == outer.getInDimSize(outDim);
      });
  return LinearLayout(std::move(composedBases), llvm::to_vector(outer.outDims),
                      compositionIsSurjective);
}
namespace {
std::unique_ptr<uint64_t[]> concatMatrices(const LinearLayout &A,
const LinearLayout &B) {
assert(A.getTotalOutDimSizeLog2() >= B.getTotalOutDimSizeLog2() &&
"A must have at least as many output bits as B");
int numColsA = A.getTotalInDimSizeLog2();
auto concat = getMatrix(A);
auto BMat = getMatrix(B);
int rowA = 0;
int rowB = 0;
for (auto [outDim, outDimSize] : A.getOutDims()) {
for (int r = 0; r < llvm::Log2_32(outDimSize); r++) {
if (r < llvm::Log2_32(B.getOutDimSize(outDim))) {
concat[rowA] |= BMat[rowB] << numColsA;
rowB++;
}
rowA++;
}
}
return concat;
}
// Solves A ∘ X = B over F2 for a layout X from B's in-dims to A's in-dims,
// by row-reducing the augmented matrix [A | B]. Requires Im(B) ⊆ Im(A).
// Columns of A without a pivot get a zero row in X.
LinearLayout lstsq(const LinearLayout &A, const LinearLayout &B) {
  assert(!A.getInDimNames().empty() &&
         "attempt to solve lstsq for empty layout");
  assert(!A.getOutDimNames().empty() &&
         "attempt to solve lstsq for layout without out dims");
  int numRows = A.getTotalOutDimSizeLog2();
  assert(numRows >= B.getTotalOutDimSizeLog2() &&
         "A.lstsq(B) called with incompatible output shapes");
  int numColsA = A.getTotalInDimSizeLog2();
  int numColsB = B.getTotalInDimSizeLog2();
  int numCols = numColsA + numColsB;
  // Each row of [A | B] is one uint64_t, and rows are later shifted right by
  // numColsA, so both bounds must hold for the bit operations to be defined.
  assert(numCols <= 64 && numColsA < 64 &&
         "A.lstsq(B) is too large to solve with 64-bit rows");
  std::unique_ptr<uint64_t[]> combinedMat = concatMatrices(A, B);
  f2reduce::inplace_rref_strided(combinedMat.get(), numRows, numCols,
                                 /*stride=*/1);
  // For each column of A, the reduced row whose lowest set bit is that column
  // (its pivot row), or -1 if the column has no pivot.
  SmallVector<int32_t> pivotRowOfCol(numColsA, -1);
  for (int r = 0; r < numRows; r++) {
    auto row = combinedMat[r];
    if (row == 0) {
      continue;
    }
    int c = __builtin_ctzll(row);
    assert(c < numColsA && "Precondition broken. Im(B) not contained in Im(A)");
    assert(pivotRowOfCol[c] == -1 &&
           "duplicate pivot => matrix not in RREF or A not injective");
    pivotRowOfCol[c] = r;
  }
  // retMat[c] holds, as a bitmask over B's columns, input bit c of X.
  std::unique_ptr<uint64_t[]> retMat(new uint64_t[numColsA]());
  for (int c = 0; c < numColsA; ++c) {
    int row = pivotRowOfCol[c];
    retMat[c] = (row == -1) ? 0 : (combinedMat[row] >> numColsA);
  }
  // Build X as a flat 1D layout, then reshape it to B's in-dims -> A's
  // in-dims.
  StringAttr inDim1D = *A.getInDimNames().begin();
  StringAttr outDim1D = *A.getOutDimNames().begin();
  LinearLayout::BasesT retBases;
  auto &bs = retBases[inDim1D];
  for (int c = 0; c < numColsB; c++) {
    int32_t basis = 0;
    for (int r = 0; r < numColsA; r++) {
      basis |= (retMat[r] >> c & 1) << r;
    }
    bs.push_back({basis});
  }
  LinearLayout retFlattened(std::move(retBases),
                            {{outDim1D, A.getTotalInDimSize()}},
                            /*requireSurjective=*/false);
  SmallVector<std::pair<StringAttr, int32_t>> retInDims;
  SmallVector<std::pair<StringAttr, int32_t>> retOutDims;
  for (StringAttr dim : B.getInDimNames()) {
    retInDims.push_back({dim, B.getInDimSize(dim)});
  }
  for (StringAttr dim : A.getInDimNames()) {
    retOutDims.push_back({dim, A.getInDimSize(dim)});
  }
  return retFlattened.reshapeIns(retInDims).reshapeOuts(retOutDims);
}
}
// Computes a layout X from this layout's in-dims to `outer`'s in-dims with
// outer ∘ X == *this, by solving a least-squares problem over F2 (see lstsq).
// Both layouts must have the same out-dims (any order), and `outer` must be at
// least as large as *this in every out-dim. In-dims on which both layouts
// agree exactly are mapped by the identity without going through lstsq.
LinearLayout LinearLayout::invertAndCompose(const LinearLayout &outer) const {
  // Named `outDimNames` so it does not shadow the `outDims` member.
  auto outDimNames = llvm::to_vector(getOutDimNames());
  assertDimsEqualIgnoringOrder(outDimNames, outer.getOutDimNames());
  const auto &B = *this;
  const auto A = outer.transposeOuts(outDimNames);
#ifndef NDEBUG
  // `assert(cond && std::string(...).c_str())` never shows the runtime-built
  // message: assert prints only the expression text, and because of
  // short-circuiting the string is built only when the check passes. Print
  // the details explicitly before failing instead.
  for (auto dim : outDimNames) {
    if (A.getOutDimSize(dim) < B.getOutDimSize(dim)) {
      llvm::errs() << "A.invertAndCompose(B) called with incompatible output "
                      "shapes in "
                   << dim.str() << ": expected " << A.getOutDimSize(dim)
                   << " >= " << B.getOutDimSize(dim) << "\n";
      assert(false &&
             "A.invertAndCompose(B) called with incompatible output shapes");
    }
  }
#endif
  SmallVector<StringAttr> identityDims;
  for (auto dim : A.getInDimNames()) {
    if (B.hasInDim(dim) &&
        A.sublayout(dim, outDimNames) == B.sublayout(dim, outDimNames)) {
      identityDims.push_back(dim);
    }
  }
  SmallVector<StringAttr> ANonIdentityInDims;
  SmallVector<StringAttr> BNonIdentityInDims;
  for (auto dim : A.getInDimNames()) {
    if (!llvm::is_contained(identityDims, dim)) {
      ANonIdentityInDims.push_back(dim);
    }
  }
  for (auto dim : B.getInDimNames()) {
    if (!llvm::is_contained(identityDims, dim)) {
      BNonIdentityInDims.push_back(dim);
    }
  }
  auto AReduced = A.sublayout(ANonIdentityInDims, outDimNames);
  auto BReduced = B.sublayout(BNonIdentityInDims, outDimNames);
  assert((ANonIdentityInDims.empty()) == (BNonIdentityInDims.empty()));
  bool isEmpty = ANonIdentityInDims.empty();
  auto ret = isEmpty ? LinearLayout::empty() : lstsq(AReduced, BReduced);
  for (auto dim : identityDims) {
    ret *= LinearLayout::identity1D(A.getInDimSize(dim), dim, dim);
  }
  return ret.transposeIns(llvm::to_vector(B.getInDimNames()))
      .transposeOuts(llvm::to_vector(A.getInDimNames()));
}
// Inverse of an invertible layout. For such layouts the pseudo-inverse is
// the inverse, so this asserts invertibility and delegates to pseudoinvert().
LinearLayout LinearLayout::invert() const {
  assert(isInvertible() &&
         "A linear layout must be surjective and square to be invertible");
  return pseudoinvert();
}
// Pseudo-inverse: builds the identity on this layout's out-space and solves
// (*this) ∘ X == identity via invertAndCompose.
LinearLayout LinearLayout::pseudoinvert() const {
  LinearLayout outIdentity = LinearLayout::empty();
  for (StringAttr outDim : getOutDimNames())
    outIdentity *=
        LinearLayout::identity1D(getOutDimSize(outDim), outDim, outDim);
  return outIdentity.invertAndCompose(*this);
}
// Drops the size-1 in-dim `dim` by reshaping onto the remaining in-dims.
LinearLayout LinearLayout::unsqueezeIn(StringAttr dim) const {
  assert(getInDimSize(dim) == 1);
  SmallVector<std::pair<StringAttr, int32_t>> remaining;
  for (StringAttr inDim : getInDimNames()) {
    if (inDim == dim)
      continue;
    remaining.push_back({inDim, getInDimSize(inDim)});
  }
  return reshapeIns(remaining);
}
// Drops the size-1 out-dim `dim` by reshaping onto the remaining out-dims
// (mirrors unsqueezeIn).
LinearLayout LinearLayout::unsqueezeOut(StringAttr dim) const {
  assert(getOutDimSize(dim) == 1);
  SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
  for (auto [outDim, outDimSize] : getOutDims()) {
    if (outDim != dim) {
      newOutDims.push_back({outDim, outDimSize});
    }
  }
  // reshapeOuts re-packs every basis into the remaining out-dims. `dim` has
  // size 1, so it contributes no bits and the other components are unchanged.
  // Building a LinearLayout from the unmodified bases would leave each basis
  // with one entry too many and fail checkInvariants() fatally.
  return reshapeOuts(newOutDims);
}
// For each in-dim, returns a bitmask of its input bits that are free
// variables: columns with no pivot after row-reducing the layout's F2 matrix.
// The lowest set bit of each non-zero reduced row is treated as that row's
// pivot column.
llvm::MapVector<StringAttr, int32_t>
LinearLayout::getFreeVariableMasks() const {
  std::unique_ptr<uint64_t[]> mat = getMatrix(*this);
  const int numRows = getTotalOutDimSizeLog2();
  const int numCols = getTotalInDimSizeLog2();
  assert(numCols <= 64);
  f2reduce::inplace_rref_strided(mat.get(), numRows, numCols, 1);
  std::set<int32_t> pivotCols;
  for (int r = 0; r < numRows; ++r) {
    if (mat[r] != 0)
      pivotCols.insert(__builtin_ctzll(mat[r]));
  }
  llvm::MapVector<StringAttr, int32_t> masks;
  int col = 0;
  for (StringAttr dim : getInDimNames()) {
    int32_t mask = 0;
    const int32_t bits = getInDimSizeLog2(dim);
    for (int i = 0; i < bits; ++i, ++col) {
      if (pivotCols.count(col) == 0)
        mask |= (1 << i);
    }
    masks[dim] = mask;
  }
  return masks;
}
// Returns a copy of the layout in which the in-dim `stripDim` keeps only its
// non-zero bases. All other in-dims and the out-dims are unchanged.
LinearLayout LinearLayout::removeZeroBasesAlongDim(StringAttr stripDim) const {
  LinearLayout::BasesT filtered;
  for (const auto &[inDim, inDimBases] : getBases()) {
    auto &dst = filtered[inDim];
    if (inDim == stripDim) {
      for (const auto &basis : inDimBases) {
        if (!llvm::all_of(basis, [](int32_t v) { return v == 0; }))
          dst.push_back(basis);
      }
    } else {
      dst = inDimBases;
    }
  }
  SmallVector<std::pair<StringAttr, int32_t>> outDimSizes;
  for (StringAttr outDim : getOutDimNames())
    outDimSizes.push_back({outDim, getOutDimSize(outDim)});
  return LinearLayout(std::move(filtered), ArrayRef(outDimSizes),
                      this->isSurjective());
}
size_t hash_value(const LinearLayout &layout) {
size_t seed = 0;
for (const auto &base : layout.getBases()) {
seed = llvm::hash_combine(seed, base.first);
for (const auto &vec : base.second) {
for (int32_t val : vec) {
seed = llvm::hash_combine(seed, val);
}
}
}
for (const auto &outDim : layout.getOutDimNames()) {
seed = llvm::hash_combine(seed, outDim, layout.getOutDimSize(outDim));
}
return seed;
}
// Layouts are equal when their out-dim names and bases match (in order) and
// their out-dim sizes agree position by position.
bool operator==(const LinearLayout &lhs, const LinearLayout &rhs) {
  if (!lhs.equalIgnoringOutDimSizes(rhs))
    return false;
  for (auto &&[lhsEntry, rhsEntry] : llvm::zip(lhs.outDims, rhs.outDims)) {
    if (lhsEntry.second != rhsEntry.second)
      return false;
  }
  return true;
}
bool LinearLayout::equalIgnoringOutDimSizes(const LinearLayout &other) const {
if (llvm::to_vector(this->getOutDimNames()) !=
llvm::to_vector(other.getOutDimNames()))
return false;
if (this->bases.size() != other.bases.size())
return false;
for (auto it1 = this->bases.begin(), it2 = other.bases.begin();
it1 != this->bases.end(); ++it1, ++it2) {
if (*it1 != *it2)
return false;
}
return true;
}
std::string LinearLayout::toString() const {
std::string ret = "\n";
std::string outDimsStr =
"[" +
join(outDims, ", ",
[](auto dimAndSize) {
auto [outDim, size] = dimAndSize;
return outDim.str() + " (size " + std::to_string(size) + ")";
}) +
"]";
if (bases.empty()) {
if (outDims.empty()) {
return "\n(empty layout)";
} else {
return "\n(empty layout with out-dims " + outDimsStr + ")";
}
}
for (const auto &[inDim, inDimBases] : bases) {
if (inDimBases.empty()) {
ret += " - " + inDim.str() + " is a size 1 dimension\n";
continue;
}
ret += " - " +
join(llvm::seq(inDimBases.size()), "\n ",
[&, &inDim = inDim, &inDimBases = inDimBases](int i) {
return inDim.str() + "=" + std::to_string(1 << i) + " -> (" +
join(inDimBases[i], ", ") + ")";
}) +
"\n";
}
ret += "where out dims are: " + outDimsStr;
return ret;
}
// Permutes the bases of `inDim` in `layout`: basis k of the result is basis
// action[k] of the input. Other in-dims and the out-dims are unchanged.
LinearLayout ColumnAction::apply(const LinearLayout &layout) const {
  assert(layout.hasInDim(inDim));
  assert(layout.getInDimSizeLog2(inDim) == inSizeLog2 &&
         "Layout has a different size than the ColumnAction");
  if (m_isIdentity)
    return layout;
  auto bases = layout.getBases();
  std::vector<std::vector<int32_t>> reordered;
  reordered.reserve(action.size());
  {
    const auto &original = bases[inDim];
    for (size_t src : action)
      reordered.push_back(original[src]);
  }
  bases[inDim] = std::move(reordered);
  SmallVector<std::pair<StringAttr, int32_t>> outDims;
  for (StringAttr outDim : layout.getOutDimNames())
    outDims.emplace_back(outDim, layout.getOutDimSize(outDim));
  return LinearLayout(std::move(bases), std::move(outDims),
                      /*requireSurjective=*/false);
}
// Permutes a register-indexed list of values the same way the action permutes
// layout bases. The action is applied to the identity layout on `inDim`, and
// each output position i takes the value at the index that layout maps i to.
SmallVector<Value> ColumnAction::apply(ValueRange values) const {
  assert(values.size() == (1 << inSizeLog2) &&
         "Values have a different size than the ColumnAction");
  assert(inDim.str() == "register" && "Values are in registers, so we can only "
                                      "apply ColumnAction to registers");
  if (m_isIdentity)
    return values;
  auto permLL = apply(LinearLayout::identity1D(values.size(), inDim, inDim));
  const int32_t count = permLL.getInDimSize(inDim);
  SmallVector<Value> ret;
  ret.reserve(count);
  for (int32_t i = 0; i < count; ++i) {
    auto mapped = permLL.apply({{inDim, i}});
    ret.push_back(values[mapped.begin()->second]);
  }
  return ret;
}
// Inverse permutation: if this action sends position k to action[k], the
// inverse sends action[k] back to k.
ColumnAction ColumnAction::inverse() const {
  auto invPerm = SmallVector<size_t>(action.size());
  for (auto [k, target] : llvm::enumerate(action))
    invPerm[target] = k;
  return ColumnAction(invPerm, inDim, inSizeLog2);
}
// Renders the action as "ColumnAction([<perm>], <inDim>, <inSizeLog2>)".
std::string ColumnAction::toString() const {
  return "ColumnAction([" + join(action, ", ") + "], " + inDim.str() + ", " +
         std::to_string(inSizeLog2) + ")";
}
}