#include "triton/Tools/GenericSwizzling.h"
#include "triton/Tools/LayoutUtils.h"
#include "triton/Tools/LinearLayout.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/Signals.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
using namespace mlir;
using namespace mlir::triton;
using mlir::triton::gpu::logBankConflictsLdSt;
using mlir::triton::gpu::optimalSwizzling;
using mlir::triton::gpu::optimalSwizzlingLdSt;
namespace {
/// Collects the bases registered under input dimension `dim` of `ll` into a
/// flat list, taking the single component of each basis vector.
///
/// Asserts that `dim` is an input dimension of `ll` and that every basis of
/// that dimension has exactly one component.
SmallVector<int32_t> flatten(const LinearLayout &ll, StringAttr dim) {
  assert(ll.hasInDim(dim) && "in dim must exist");
  const auto dimBases = ll.getBases().lookup(dim);
  SmallVector<int32_t> flat;
  flat.reserve(dimBases.size());
  for (size_t idx = 0, count = dimBases.size(); idx != count; ++idx) {
    const auto &entry = dimBases[idx];
    assert(entry.size() == 1 && "basis must be a single int32_t");
    flat.push_back(entry.front());
  }
  return flat;
}
class SwizzleTest : public ::testing::Test {
public:
StringAttr S(StringRef str) { return StringAttr::get(&ctx, str); }
protected:
MLIRContext ctx;
};
// Builds a 128x128 layout and its {1, 0} transpose, computes a swizzled layout
// for the pair with optimalSwizzlingLdSt, and expects both values returned by
// logBankConflictsLdSt to be 0.
TEST_F(SwizzleTest, Test128x128Float8Transpose) {
  const StringAttr kRegister = S("register");
  const StringAttr kLane = S("lane");
  const StringAttr kWarp = S("warp");
  const StringAttr kDim0 = S("dim0");
  const StringAttr kDim1 = S("dim1");
  // Bit width passed to both helpers; matches the "Float8" in the test name.
  constexpr int kBitwidth = 8;

  LinearLayout src(
      {{kRegister, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}},
       {kLane, {{0, 16}, {0, 32}, {0, 64}, {4, 0}, {8, 0}}},
       {kWarp, {{16, 0}, {32, 0}, {64, 0}}}},
      {{kDim0, 128}, {kDim1, 128}}, /*requireSurjective=*/true);
  const auto dst = transposeLinearLayout(src, {1, 0});

  const auto swizzled = optimalSwizzlingLdSt(src, dst, kBitwidth);
  // Presumably the read- and write-side conflict measures, in that order.
  const auto [rConflicts, wConflicts] =
      logBankConflictsLdSt(src, dst, swizzled, kBitwidth);
  EXPECT_EQ(rConflicts, 0);
  EXPECT_EQ(wConflicts, 0);
}
// Builds two 16x16 layouts (named "blocked" and "mma" in this test), computes a
// swizzled layout for the pair with optimalSwizzlingLdSt, and expects both
// values returned by logBankConflictsLdSt to be 0.
TEST_F(SwizzleTest, Test16x16Bf16BlockedMma) {
  const StringAttr kRegister = S("register");
  const StringAttr kLane = S("lane");
  const StringAttr kWarp = S("warp");
  const StringAttr kDim0 = S("dim0");
  const StringAttr kDim1 = S("dim1");
  // Bit width passed to both helpers; matches the "Bf16" in the test name.
  constexpr int kBitwidth = 16;

  LinearLayout blockedLayout(
      {{kRegister, {{0, 1}, {0, 2}, {0, 4}}},
       {kLane, {{0, 8}, {1, 0}, {2, 0}, {4, 0}, {8, 0}}},
       {kWarp, {}}},
      {{kDim0, 16}, {kDim1, 16}}, /*requireSurjective=*/true);
  LinearLayout mmaLayout(
      {{kRegister, {{0, 1}, {8, 0}, {0, 8}}},
       {kLane, {{0, 2}, {0, 4}, {1, 0}, {2, 0}, {4, 0}}},
       {kWarp, {}}},
      {{kDim0, 16}, {kDim1, 16}}, /*requireSurjective=*/true);

  const auto swizzled =
      optimalSwizzlingLdSt(blockedLayout, mmaLayout, kBitwidth);
  // Presumably the read- and write-side conflict measures, in that order.
  const auto [rConflicts, wConflicts] =
      logBankConflictsLdSt(blockedLayout, mmaLayout, swizzled, kBitwidth);
  EXPECT_EQ(rConflicts, 0);
  EXPECT_EQ(wConflicts, 0);
}
// Builds two 16x256 layouts (named "blocked" and "mma" in this test), computes
// a swizzled layout for the pair with optimalSwizzlingLdSt, and expects both
// values returned by logBankConflictsLdSt to be 0.
TEST_F(SwizzleTest, Test16x256U4Mma) {
  const StringAttr kRegister = S("register");
  const StringAttr kLane = S("lane");
  const StringAttr kWarp = S("warp");
  const StringAttr kDim0 = S("dim0");
  const StringAttr kDim1 = S("dim1");
  // Bit width passed to both helpers; matches the "U4" in the test name.
  constexpr int kBitwidth = 4;

  LinearLayout blockedLayout(
      {{kRegister,
        {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {4, 0}, {8, 0}}},
       {kLane, {{0, 32}, {0, 64}, {0, 128}, {1, 0}, {2, 0}}},
       {kWarp, {}}},
      {{kDim0, 16}, {kDim1, 256}}, /*requireSurjective=*/true);
  LinearLayout mmaLayout(
      {{kRegister,
        {{0, 1}, {0, 2}, {0, 4}, {8, 0}, {0, 32}, {0, 64}, {0, 128}}},
       {kLane, {{0, 8}, {0, 16}, {1, 0}, {2, 0}, {4, 0}}},
       {kWarp, {}}},
      {{kDim0, 16}, {kDim1, 256}}, /*requireSurjective=*/true);

  const auto swizzled =
      optimalSwizzlingLdSt(blockedLayout, mmaLayout, kBitwidth);
  // Presumably the read- and write-side conflict measures, in that order.
  const auto [rConflicts, wConflicts] =
      logBankConflictsLdSt(blockedLayout, mmaLayout, swizzled, kBitwidth);
  EXPECT_EQ(rConflicts, 0);
  EXPECT_EQ(wConflicts, 0);
}
// Builds two explicitly specified 32x16 layouts, computes a swizzled layout
// for the pair with optimalSwizzlingLdSt, and expects both values returned by
// logBankConflictsLdSt to be 0.
TEST_F(SwizzleTest, Test32x16F32Transpose) {
  const StringAttr kRegister = S("register");
  const StringAttr kLane = S("lane");
  const StringAttr kWarp = S("warp");
  const StringAttr kDim0 = S("dim0");
  const StringAttr kDim1 = S("dim1");
  // Bit width passed to both helpers; matches the "F32" in the test name.
  constexpr int kBitwidth = 32;

  LinearLayout src(
      {{kRegister, {{4, 0}, {8, 0}, {16, 0}}},
       {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}}},
       {kWarp, {{2, 0}}}},
      {{kDim0, 32}, {kDim1, 16}}, /*requireSurjective=*/true);
  LinearLayout dst(
      {{kRegister, {{0, 2}, {0, 4}, {0, 8}}},
       {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}}},
       {kWarp, {{0, 1}}}},
      {{kDim0, 32}, {kDim1, 16}}, /*requireSurjective=*/true);

  const auto swizzled = optimalSwizzlingLdSt(src, dst, kBitwidth);
  // Presumably the read- and write-side conflict measures, in that order.
  const auto [rConflicts, wConflicts] =
      logBankConflictsLdSt(src, dst, swizzled, kBitwidth);
  EXPECT_EQ(rConflicts, 0);
  EXPECT_EQ(wConflicts, 0);
}
// Builds two explicitly specified 128x128 layouts, computes a swizzled layout
// for the pair with optimalSwizzlingLdSt, and expects both values returned by
// logBankConflictsLdSt to be 0.
TEST_F(SwizzleTest, Test128x128F16Transpose) {
  const StringAttr kRegister = S("register");
  const StringAttr kLane = S("lane");
  const StringAttr kWarp = S("warp");
  const StringAttr kDim0 = S("dim0");
  const StringAttr kDim1 = S("dim1");
  // Bit width passed to both helpers; matches the "F16" in the test name.
  constexpr int kBitwidth = 16;

  LinearLayout src(
      {{kRegister, {{1, 0}, {2, 0}, {4, 0}, {0, 32}, {0, 64}}},
       {kLane, {{8, 0}, {16, 0}, {32, 0}, {64, 0}, {0, 1}}},
       {kWarp, {{0, 2}, {0, 4}, {0, 8}, {0, 16}}}},
      {{kDim0, 128}, {kDim1, 128}}, /*requireSurjective=*/true);
  LinearLayout dst(
      {{kRegister, {{0, 1}, {0, 2}, {0, 4}, {32, 0}, {64, 0}}},
       {kLane, {{0, 8}, {0, 16}, {0, 32}, {0, 64}, {1, 0}}},
       {kWarp, {{2, 0}, {4, 0}, {8, 0}, {16, 0}}}},
      {{kDim0, 128}, {kDim1, 128}}, /*requireSurjective=*/true);

  const auto swizzled = optimalSwizzlingLdSt(src, dst, kBitwidth);
  // Presumably the read- and write-side conflict measures, in that order.
  const auto [rConflicts, wConflicts] =
      logBankConflictsLdSt(src, dst, swizzled, kBitwidth);
  EXPECT_EQ(rConflicts, 0);
  EXPECT_EQ(wConflicts, 0);
}
}
int main(int argc, char *argv[]) {
llvm::sys::PrintStackTraceOnErrorSignal(argv[0]);
testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}