#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "llvm/Support/Compiler.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <memory>
using namespace mlir;
using namespace mlir::sparse_tensor;
namespace {
// X-macro listing every binary operation these tests use. Each entry is
// DO(<short name>, <TensorExp::Kind>); expanding it with IMPL_BINOP_PATTERN
// and IMPL_BINOP_EXPR generates the `<name>Match` pattern helpers and the
// `<name>Expr` expression builders. Comments cannot go inside the macro body
// (a `//` comment would swallow the line continuation), so they live here.
#define FOREVERY_BINOP(DO) \
DO(mulf, TensorExp::Kind::kMulF) \
DO(mulc, TensorExp::Kind::kMulC) \
DO(muli, TensorExp::Kind::kMulI) \
DO(addf, TensorExp::Kind::kAddF) \
DO(addc, TensorExp::Kind::kAddC) \
DO(addi, TensorExp::Kind::kAddI) \
DO(subf, TensorExp::Kind::kSubF) \
DO(subc, TensorExp::Kind::kSubC) \
DO(subi, TensorExp::Kind::kSubI) \
DO(andi, TensorExp::Kind::kAndI) \
DO(xori, TensorExp::Kind::kXorI) \
DO(ori, TensorExp::Kind::kOrI) \
DO(cmpf, TensorExp::Kind::kCmpF) \
DO(cmpi, TensorExp::Kind::kCmpI)
// Binary ops that the tests below expect to produce a disjunctive lattice
// (three points for two sparse operands). Expands TEST(op, EXTRA) for each
// op; EXTRA is forwarded unchanged so the PAIR_OF macros can pair ops.
#define FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, EXTRA) \
TEST(addf, EXTRA) \
TEST(addc, EXTRA) \
TEST(addi, EXTRA) \
TEST(xori, EXTRA) \
TEST(ori, EXTRA)
// Binary ops that the tests below expect to produce a conjunctive lattice
// (a single point). Expands TEST(op, EXTRA) for each op.
#define FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, EXTRA) \
TEST(mulf, EXTRA) \
TEST(mulc, EXTRA) \
TEST(muli, EXTRA) \
TEST(andi, EXTRA)
// Single-op variants: EXTRA is the unused placeholder "".
#define FOREVERY_COMMON_DISJ_BINOP(TEST) \
FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, "")
#define FOREVERY_COMMON_CONJ_BINOP(TEST) \
FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, "")
// Expands TEST(conj, disj) for every (conjunctive, disjunctive) op pair.
#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST) \
FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addf) \
FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addc) \
FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addi) \
FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, xori) \
FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, ori)
// Expands TEST(conj1, conj2) for every ordered pair of conjunctive ops.
#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST) \
FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, mulf) \
FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, mulc) \
FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, muli) \
FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, andi)
// Expands TEST(disj1, disj2) for every ordered pair of disjunctive ops.
// (Lists ori before xori, unlike the CONJ_DISJ list; only the expansion order
// differs.)
#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(TEST) \
FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addf) \
FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addc) \
FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addi) \
FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, ori) \
FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, xori)
/// A pattern describing an expected expression tree. A pattern is one of:
/// a tensor leaf (`kind == kTensor`, `tid` valid), the synthetic zero
/// (`kind == kSynZero`, no payload), or a binary node (`children` valid).
/// Binary nodes hold references to their operand patterns, so the operands
/// must outlive the node. In these tests the node is typically built and
/// consumed within a single full-expression.
struct Match {
/// References to the two operand patterns of a binary node (not owned).
struct Children {
Children(const Match &e0, const Match &e1) : e0(e0), e1(e1) {}
const Match &e0;
const Match &e1;
};
/// Synthetic-zero pattern; neither union member is initialized.
Match() : kind(TensorExp::Kind::kSynZero) {}
/// Tensor-leaf pattern for tensor `tid`.
Match(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {}
/// Binary-node pattern of kind `kind` over operands `e0` and `e1`.
Match(TensorExp::Kind kind, const Match &e0, const Match &e1)
: kind(kind), children(e0, e1) {
// NOTE(review): presumably relies on every binary kind being ordered at or
// after kMulF in TensorExp::Kind -- confirm in Merger.h.
assert(kind >= TensorExp::Kind::kMulF);
}
TensorExp::Kind kind;
// Payload selected by `kind`: `tid` for kTensor, `children` for binary nodes.
union {
TensorId tid;
Children children;
};
};
/// Builds a pattern that matches the tensor leaf for tensor `tid`.
static Match tensorMatch(TensorId tid) {
  return {tid};
}

/// Builds a pattern that matches the synthetic zero expression.
static Match synZeroMatch() {
  return Match{};
}
// Generates `<OP>Match(lhs, rhs)`: a binary-node pattern of kind KIND whose
// operands are `lhs` and `rhs`. The result only references its operands, so
// both must outlive it.
#define IMPL_BINOP_PATTERN(OP, KIND) \
  LLVM_ATTRIBUTE_UNUSED static Match OP##Match(const Match &lhs, \
                                               const Match &rhs) { \
    return {KIND, lhs, rhs}; \
  }
FOREVERY_BINOP(IMPL_BINOP_PATTERN)
#undef IMPL_BINOP_PATTERN
/// Base fixture for the Merger lattice tests, parameterized over the
/// `LevelFormat` that derived fixtures assign to some of their tensors.
///
/// It owns a `Merger` for `numTensors` tensors and `numLoops` loops and
/// registers one tensor expression per tensor up front (see `tensor(t)`). It
/// also provides helpers that build binary expressions (`<op>Expr`) and check
/// a lattice set's points against `Match` patterns and expected bit vectors.
class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
protected:
  // NOTE(review): the third Merger constructor argument is also `numLoops`;
  // presumably a maximum level rank -- confirm against Merger.h.
  MergerTestBase(unsigned numTensors, unsigned numLoops)
      : merger(numTensors, numLoops, numLoops) {
    tensors.reserve(numTensors);
    for (unsigned t = 0; t < numTensors; t++)
      tensors.push_back(merger.addTensorExp(tid(t)));
  }

  /// Returns the TensorId for tensor number `t`.
  TensorId tid(unsigned t) const { return merger.makeTensorId(t); }
  /// Returns the LoopId for loop number `i`.
  LoopId lid(unsigned i) const { return merger.makeLoopId(i); }
  /// Returns the expression registered for tensor `t` by the constructor.
  ExprId tensor(unsigned t) const {
    assert(t < tensors.size());
    return tensors[t];
  }

  // Generates `<OP>Expr(e0, e1)`, which adds a binary expression of kind KIND
  // to the merger and returns its id.
#define IMPL_BINOP_EXPR(OP, KIND) \
  LLVM_ATTRIBUTE_UNUSED ExprId OP##Expr(ExprId e0, ExprId e1) { \
    return merger.addExp(KIND, e0, e1); \
  }
  FOREVERY_BINOP(IMPL_BINOP_EXPR)
#undef IMPL_BINOP_EXPR

  /// Returns true if some lattice point at an index in [lo, lo + n) of set
  /// `s` matches `pattern` and carries `bits`. When `simple` is true the
  /// comparison uses the point's simplified condition. Indices at or past the
  /// end of the set are never read.
  bool latPointWithinRange(LatSetId s, unsigned lo, unsigned n,
                           const Match &pattern, const BitVector &bits,
                           bool simple) {
    // `expectNumLatPoints` is a non-fatal check, so a test keeps probing
    // after a size mismatch. Bound the scan by the real set size so a short
    // lattice yields a failed expectation rather than out-of-range reads.
    const auto &points = merger.set(s);
    for (unsigned k = lo, hi = lo + n; k < hi && k < points.size(); ++k) {
      if (compareExpression(merger.lat(points[k]).exp, pattern) &&
          compareBits(s, k, bits, simple))
        return true;
    }
    return false;
  }

  /// Expects a matching lattice point somewhere in [lo, lo + n) of set `s`.
  void expectLatPointWithinRange(LatSetId s, unsigned lo, unsigned n,
                                 const Match &pattern, const BitVector &bits,
                                 bool simple = false) {
    EXPECT_TRUE(latPointWithinRange(s, lo, n, pattern, bits, simple));
  }

  /// Expects the lattice point at index `lo` of set `s` to match.
  void expectLatPoint(LatSetId s, unsigned lo, const Match &pattern,
                      const BitVector &bits, bool simple = false) {
    EXPECT_TRUE(latPointWithinRange(s, lo, 1, pattern, bits, simple));
  }

  /// Builds the bit vector with one bit set per (loop, tensor) pair.
  BitVector loopsToBits(const std::vector<std::pair<LoopId, TensorId>> &loops) {
    // Bits are indexed by `makeTensorLoopId(tensor, loop)`, so the vector must
    // span every (tensor, loop) pair, not just one loop's worth of tensors.
    // `getNumTensors()` also counts the output and synthetic tensors.
    // NOTE(review): assumes Merger sizes lattice-point bits as
    // numTensors * numLoops -- confirm in Merger.h. With one loop this equals
    // the previous size.
    BitVector testBits(merger.getNumTensors() * merger.getNumLoops(), false);
    for (auto [loop, tensorId] : loops)
      testBits.set(merger.makeTensorLoopId(tensorId, loop));
    return testBits;
  }

  /// Compares point `k` of set `s` (its simplified condition when `simple`,
  /// its full bits otherwise) with `bits`.
  bool compareBits(LatSetId s, unsigned k, const BitVector &bits, bool simple) {
    const auto &point = merger.lat(merger.set(s)[k]);
    return (simple ? point.simple : point.bits) == bits;
  }

  /// Expects lattice set `s` to contain exactly `n` points.
  void expectNumLatPoints(LatSetId s, unsigned n) {
    EXPECT_EQ(merger.set(s).size(), n);
  }

  /// Structurally compares expression `e` with `pattern`. Kinds must be
  /// equal; tensor leaves compare tensor ids; unary kinds recurse on the
  /// first child and binary kinds on both. kDenseOp compares its second child
  /// only when that child id is valid. Invariants and loop variables are not
  /// supported.
  bool compareExpression(ExprId e, const Match &pattern) {
    const auto &tensorExp = merger.exp(e);
    if (tensorExp.kind != pattern.kind)
      return false;
    switch (tensorExp.kind) {
    case TensorExp::Kind::kTensor:
      return tensorExp.tensor == pattern.tid;
    case TensorExp::Kind::kSynZero:
      return true;
    case TensorExp::Kind::kInvariant:
      llvm_unreachable("invariant not handled yet");
    case TensorExp::Kind::kLoopVar:
      llvm_unreachable("loop-variables not handled yet");
    case TensorExp::Kind::kAbsF:
    case TensorExp::Kind::kAbsC:
    case TensorExp::Kind::kAbsI:
    case TensorExp::Kind::kCeilF:
    case TensorExp::Kind::kFloorF:
    case TensorExp::Kind::kSqrtF:
    case TensorExp::Kind::kSqrtC:
    case TensorExp::Kind::kExpm1F:
    case TensorExp::Kind::kExpm1C:
    case TensorExp::Kind::kLog1pF:
    case TensorExp::Kind::kLog1pC:
    case TensorExp::Kind::kRelu:
    case TensorExp::Kind::kSinF:
    case TensorExp::Kind::kSinC:
    case TensorExp::Kind::kTanhF:
    case TensorExp::Kind::kTanhC:
    case TensorExp::Kind::kNegF:
    case TensorExp::Kind::kNegC:
    case TensorExp::Kind::kNegI:
    case TensorExp::Kind::kTruncF:
    case TensorExp::Kind::kExtF:
    case TensorExp::Kind::kCastFS:
    case TensorExp::Kind::kCastFU:
    case TensorExp::Kind::kCastSF:
    case TensorExp::Kind::kCastUF:
    case TensorExp::Kind::kCastS:
    case TensorExp::Kind::kCastU:
    case TensorExp::Kind::kCastIdx:
    case TensorExp::Kind::kTruncI:
    case TensorExp::Kind::kCIm:
    case TensorExp::Kind::kCRe:
    case TensorExp::Kind::kBitCast:
    case TensorExp::Kind::kSelect:
    case TensorExp::Kind::kBinaryBranch:
    case TensorExp::Kind::kUnary:
      return compareExpression(tensorExp.children.e0, pattern.children.e0);
    case TensorExp::Kind::kMulF:
    case TensorExp::Kind::kMulC:
    case TensorExp::Kind::kMulI:
    case TensorExp::Kind::kDivF:
    case TensorExp::Kind::kDivC:
    case TensorExp::Kind::kDivS:
    case TensorExp::Kind::kDivU:
    case TensorExp::Kind::kAddF:
    case TensorExp::Kind::kAddC:
    case TensorExp::Kind::kAddI:
    case TensorExp::Kind::kSubF:
    case TensorExp::Kind::kSubC:
    case TensorExp::Kind::kSubI:
    case TensorExp::Kind::kAndI:
    case TensorExp::Kind::kOrI:
    case TensorExp::Kind::kXorI:
    case TensorExp::Kind::kCmpF:
    case TensorExp::Kind::kCmpI:
    case TensorExp::Kind::kShrS:
    case TensorExp::Kind::kShrU:
    case TensorExp::Kind::kShlI:
    case TensorExp::Kind::kBinary:
    case TensorExp::Kind::kReduce:
      return compareExpression(tensorExp.children.e0, pattern.children.e0) &&
             compareExpression(tensorExp.children.e1, pattern.children.e1);
    case TensorExp::Kind::kDenseOp: {
      bool eq = compareExpression(tensorExp.children.e0, pattern.children.e0);
      if (eq && tensorExp.children.e1 != sparse_tensor::detail::kInvalidId)
        return compareExpression(tensorExp.children.e1, pattern.children.e1);
      return eq;
    }
    }
    llvm_unreachable("unexpected kind");
  }

  Merger merger;

private:
  /// Expression ids registered per tensor by the constructor.
  SmallVector<ExprId> tensors;
};
class MergerTest3T1L : public MergerTestBase {
protected:
MergerTest3T1L() : MergerTestBase(3, 1) {
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
}
};
INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L,
::testing::Values(LevelFormat::Dense,
LevelFormat::Batch));
class MergerTest4T1L : public MergerTestBase {
protected:
MergerTest4T1L() : MergerTestBase(4, 1) {
EXPECT_TRUE(merger.getOutTensorID() == tid(3));
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
}
};
INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L,
::testing::Values(LevelFormat::Dense,
LevelFormat::Batch));
class MergerTest3T1LD : public MergerTestBase {
protected:
MergerTest3T1LD() : MergerTestBase(3, 1) {
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
}
};
INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD,
::testing::Values(LevelFormat::Dense,
LevelFormat::Batch));
class MergerTest4T1LU : public MergerTestBase {
protected:
MergerTest4T1LU() : MergerTestBase(4, 1) {
EXPECT_TRUE(merger.getOutTensorID() == tid(3));
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
}
};
INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU,
::testing::Values(LevelFormat::Dense,
LevelFormat::Batch));
class MergerTest3T1LSo : public MergerTestBase {
protected:
MergerTest3T1LSo() : MergerTestBase(3, 1) {
EXPECT_TRUE(merger.getOutTensorID() == tid(2));
EXPECT_TRUE(merger.getSynTensorID() == tid(3));
merger.setHasSparseOut(true);
merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Undef);
merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
}
};
INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo,
::testing::Values(LevelFormat::Dense));
}
// Nested conjunctions `(t0 CONJ1 t1) CONJ2 t2` on MergerTest4T1LU. The
// lattice is a single point over t0, t1 and t2; after optimizeSet, its
// simplified condition keeps only t1.
#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2) \
  TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) { \
    const auto loop = lid(0); \
    const auto a = tid(0); \
    const auto b = tid(1); \
    const auto c = tid(2); \
    const Match &pa = tensorMatch(a); \
    const Match &pb = tensorMatch(b); \
    const Match &pc = tensorMatch(c); \
    const auto inner = CONJ1##Expr(tensor(0), tensor(1)); \
    const auto lattices = \
        merger.buildLattices(CONJ2##Expr(inner, tensor(2)), loop); \
    expectNumLatPoints(lattices, 1); \
    expectLatPoint(lattices, 0, CONJ2##Match(CONJ1##Match(pa, pb), pc), \
                   loopsToBits({{loop, a}, {loop, b}, {loop, c}})); \
    const auto optimized = merger.optimizeSet(lattices); \
    expectNumLatPoints(optimized, 1); \
    expectLatPoint(optimized, 0, CONJ2##Match(CONJ1##Match(pa, pb), pc), \
                   loopsToBits({{loop, b}}), true); \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
#undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF
// Nested conjunctions `(t0 CONJ1 t1) CONJ2 t2` on MergerTest3T1LSo. The single
// lattice point's expected bits name t0, t1 and the synthetic tensor t3 (not
// t2). After optimizeSet, the simplified condition keeps only t3.
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2) \
  TEST_P(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) { \
    const auto loop = lid(0); \
    const auto a = tid(0); \
    const auto b = tid(1); \
    const auto c = tid(2); \
    const auto syn = tid(3); \
    const Match &pa = tensorMatch(a); \
    const Match &pb = tensorMatch(b); \
    const Match &pc = tensorMatch(c); \
    const auto inner = CONJ1##Expr(tensor(0), tensor(1)); \
    const auto lattices = \
        merger.buildLattices(CONJ2##Expr(inner, tensor(2)), loop); \
    expectNumLatPoints(lattices, 1); \
    expectLatPoint(lattices, 0, CONJ2##Match(CONJ1##Match(pa, pb), pc), \
                   loopsToBits({{loop, a}, {loop, b}, {loop, syn}})); \
    const auto optimized = merger.optimizeSet(lattices); \
    expectNumLatPoints(optimized, 1); \
    expectLatPoint(optimized, 0, CONJ2##Match(CONJ1##Match(pa, pb), pc), \
                   loopsToBits({{loop, syn}}), true); \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
#undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT
// Disjunctive `t0 OP t1` over two compressed inputs (MergerTest3T1L). The
// lattice is the joint point {t0, t1} followed, in either order, by {t0} and
// {t1}. optimizeSet keeps all three; afterwards the simplified conditions
// are compared instead of the full bits.
#define IMPL_MERGER_TEST_DISJ(OP, UNUSED) \
  TEST_P(MergerTest3T1L, vector_##OP) { \
    const auto loop = lid(0); \
    const auto lhs = tid(0); \
    const auto rhs = tid(1); \
    const Match &lhsPat = tensorMatch(lhs); \
    const Match &rhsPat = tensorMatch(rhs); \
    auto check = [&](LatSetId ls, bool simple) { \
      expectNumLatPoints(ls, 3); \
      expectLatPoint(ls, 0, OP##Match(lhsPat, rhsPat), \
                     loopsToBits({{loop, lhs}, {loop, rhs}}), simple); \
      expectLatPointWithinRange(ls, 1, 2, lhsPat, \
                                loopsToBits({{loop, lhs}}), simple); \
      expectLatPointWithinRange(ls, 1, 2, rhsPat, \
                                loopsToBits({{loop, rhs}}), simple); \
    }; \
    const auto lattices = \
        merger.buildLattices(OP##Expr(tensor(0), tensor(1)), loop); \
    check(lattices, false); \
    check(merger.optimizeSet(lattices), true); \
  }
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
#undef IMPL_MERGER_TEST_DISJ
// Conjunctive `t0 OP t1` over two compressed inputs (MergerTest3T1L). There is
// a single point {t0, t1} both before optimizeSet (full bits) and after it
// (simplified condition).
#define IMPL_MERGER_TEST_CONJ(OP, UNUSED) \
  TEST_P(MergerTest3T1L, vector_##OP) { \
    const auto loop = lid(0); \
    const auto lhs = tid(0); \
    const auto rhs = tid(1); \
    const Match &lhsPat = tensorMatch(lhs); \
    const Match &rhsPat = tensorMatch(rhs); \
    auto check = [&](LatSetId ls, bool simple) { \
      expectNumLatPoints(ls, 1); \
      expectLatPoint(ls, 0, OP##Match(lhsPat, rhsPat), \
                     loopsToBits({{loop, lhs}, {loop, rhs}}), simple); \
    }; \
    const auto lattices = \
        merger.buildLattices(OP##Expr(tensor(0), tensor(1)), loop); \
    check(lattices, false); \
    check(merger.optimizeSet(lattices), true); \
  }
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
#undef IMPL_MERGER_TEST_CONJ
// `(t0 CONJ t1) DISJ t2` over three compressed inputs (MergerTest4T1L). The
// lattice is the joint point followed, in either order, by {t0, t1} and {t2}.
// optimizeSet leaves the same three points with the same full bits.
#define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
  TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
    const auto loop = lid(0); \
    const auto a = tid(0); \
    const auto b = tid(1); \
    const auto c = tid(2); \
    const Match &pa = tensorMatch(a); \
    const Match &pb = tensorMatch(b); \
    const Match &pc = tensorMatch(c); \
    auto check = [&](LatSetId ls) { \
      expectNumLatPoints(ls, 3); \
      expectLatPoint(ls, 0, DISJ##Match(CONJ##Match(pa, pb), pc), \
                     loopsToBits({{loop, a}, {loop, b}, {loop, c}})); \
      expectLatPointWithinRange(ls, 1, 2, CONJ##Match(pa, pb), \
                                loopsToBits({{loop, a}, {loop, b}})); \
      expectLatPointWithinRange(ls, 1, 2, pc, loopsToBits({{loop, c}})); \
    }; \
    const auto inner = CONJ##Expr(tensor(0), tensor(1)); \
    const auto lattices = \
        merger.buildLattices(DISJ##Expr(inner, tensor(2)), loop); \
    check(lattices); \
    check(merger.optimizeSet(lattices)); \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
#undef IMPL_MERGER_TEST_CONJ_DISJ
// `(t0 DISJ1 t1) DISJ2 t2` over three compressed inputs (MergerTest4T1L). The
// lattice has seven points: the joint point first, then every non-empty
// proper subset of {t0, t1, t2} in some order. optimizeSet keeps all seven.
// The generated test names use `Vector_` (capital V), unlike the siblings.
#define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
  TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
    const auto loop = lid(0); \
    const auto a = tid(0); \
    const auto b = tid(1); \
    const auto c = tid(2); \
    const Match &pa = tensorMatch(a); \
    const Match &pb = tensorMatch(b); \
    const Match &pc = tensorMatch(c); \
    auto check = [&](LatSetId ls) { \
      expectNumLatPoints(ls, 7); \
      expectLatPoint(ls, 0, DISJ2##Match(DISJ1##Match(pa, pb), pc), \
                     loopsToBits({{loop, a}, {loop, b}, {loop, c}})); \
      expectLatPointWithinRange(ls, 1, 6, DISJ2##Match(pb, pc), \
                                loopsToBits({{loop, b}, {loop, c}})); \
      expectLatPointWithinRange(ls, 1, 6, DISJ2##Match(pa, pc), \
                                loopsToBits({{loop, a}, {loop, c}})); \
      expectLatPointWithinRange(ls, 1, 6, DISJ1##Match(pa, pb), \
                                loopsToBits({{loop, a}, {loop, b}})); \
      expectLatPointWithinRange(ls, 1, 6, pc, loopsToBits({{loop, c}})); \
      expectLatPointWithinRange(ls, 1, 6, pb, loopsToBits({{loop, b}})); \
      expectLatPointWithinRange(ls, 1, 6, pa, loopsToBits({{loop, a}})); \
    }; \
    const auto inner = DISJ1##Expr(tensor(0), tensor(1)); \
    const auto lattices = \
        merger.buildLattices(DISJ2##Expr(inner, tensor(2)), loop); \
    check(lattices); \
    check(merger.optimizeSet(lattices)); \
  }
FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
#undef IMPL_MERGER_TEST_DISJ_DISJ
// `(t0 CONJ1 t1) CONJ2 t2` over three compressed inputs (MergerTest4T1L). The
// lattice is a single point over all three tensors. It is checked on the full
// bits before optimizeSet and on the simplified condition after it.
#define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
  TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
    const auto loop = lid(0); \
    const auto a = tid(0); \
    const auto b = tid(1); \
    const auto c = tid(2); \
    const Match &pa = tensorMatch(a); \
    const Match &pb = tensorMatch(b); \
    const Match &pc = tensorMatch(c); \
    auto check = [&](LatSetId ls, bool simple) { \
      expectNumLatPoints(ls, 1); \
      expectLatPoint(ls, 0, CONJ2##Match(CONJ1##Match(pa, pb), pc), \
                     loopsToBits({{loop, a}, {loop, b}, {loop, c}}), simple); \
    }; \
    const auto inner = CONJ1##Expr(tensor(0), tensor(1)); \
    const auto lattices = \
        merger.buildLattices(CONJ2##Expr(inner, tensor(2)), loop); \
    check(lattices, false); \
    check(merger.optimizeSet(lattices), true); \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
#undef IMPL_MERGER_TEST_CONJ_CONJ
// Disjunctive `t0 OP t1` on MergerTest3T1LD: t0 is compressed and t1 takes
// the parameterized format. The lattice starts with three points.
// optimizeSet leaves two: the joint point with simplified condition
// {t0, t1}, and at index 1 the t1-only point with simplified condition {t1}.
#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED) \
  TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
    const auto e = OP##Expr(tensor(0), tensor(1)); \
    const auto l0 = lid(0); \
    const auto t0 = tid(0); \
    const auto t1 = tid(1); \
    const Match &p0 = tensorMatch(t0); \
    const Match &p1 = tensorMatch(t1); \
    auto s = merger.buildLattices(e, l0); \
    \
    expectNumLatPoints(s, 3); \
    expectLatPoint(s, 0, OP##Match(p0, p1), \
                   loopsToBits({{l0, t0}, {l0, t1}})); \
    expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}})); \
    expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}})); \
    \
    s = merger.optimizeSet(s); \
    expectNumLatPoints(s, 2); \
    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
                   true); \
    expectLatPoint(s, 1, p1, loopsToBits({{l0, t1}}), true); \
  }
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
// Undefine the macro defined just above. The previous line undefined
// IMPL_MERGER_TEST_OPTIMIZED_CONJ (not yet defined) and leaked this one.
#undef IMPL_MERGER_TEST_OPTIMIZED_DISJ
// Conjunctive `t0 OP t1` on MergerTest3T1LD: t0 is compressed and t1 takes
// the parameterized format. The lattice is a single point over both tensors.
// After optimizeSet, its simplified condition keeps only t0.
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED) \
  TEST_P(MergerTest3T1LD, vector_opted_##OP) { \
    const auto loop = lid(0); \
    const auto lhs = tid(0); \
    const auto rhs = tid(1); \
    const Match &lhsPat = tensorMatch(lhs); \
    const Match &rhsPat = tensorMatch(rhs); \
    const auto lattices = \
        merger.buildLattices(OP##Expr(tensor(0), tensor(1)), loop); \
    expectNumLatPoints(lattices, 1); \
    expectLatPoint(lattices, 0, OP##Match(lhsPat, rhsPat), \
                   loopsToBits({{loop, lhs}, {loop, rhs}})); \
    const auto optimized = merger.optimizeSet(lattices); \
    expectNumLatPoints(optimized, 1); \
    expectLatPoint(optimized, 0, OP##Match(lhsPat, rhsPat), \
                   loopsToBits({{loop, lhs}}), true); \
  }
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
// cmpi over two compressed inputs. Like the other disjunctions it yields
// three lattice points, but each singleton point compares its operand with a
// synthetic zero. With no dense inputs, optimizeSet keeps all three. The
// point counts are checked explicitly, as in vector_<op>, so a short lattice
// is reported directly.
TEST_P(MergerTest3T1L, vector_cmp) {
  const auto e = cmpiExpr(tensor(0), tensor(1));
  const auto l0 = lid(0);
  const auto t0 = tid(0);
  const auto t1 = tid(1);
  const Match &zero = synZeroMatch();
  const Match &p0 = tensorMatch(t0);
  const Match &p1 = tensorMatch(t1);
  auto s = merger.buildLattices(e, l0);
  expectNumLatPoints(s, 3);
  expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
  expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
                            loopsToBits({{l0, t0}}));
  expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
                            loopsToBits({{l0, t1}}));
  s = merger.optimizeSet(s);
  expectNumLatPoints(s, 3);
  expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
  expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
                            loopsToBits({{l0, t0}}));
  expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
                            loopsToBits({{l0, t1}}));
}
// cmpi with a compressed t0 and a t1 in the parameterized format. The lattice
// starts with three points. As in vector_opted_<op>, optimizeSet is expected
// to leave two: the joint point, and at index 1 the cmp(zero, t1) point. The
// surviving point is checked at its exact index rather than by a range scan,
// so the check never probes past the end of the two-point set.
TEST_P(MergerTest3T1LD, vector_cmp) {
  const auto e = cmpiExpr(tensor(0), tensor(1));
  const auto l0 = lid(0);
  const auto t0 = tid(0);
  const auto t1 = tid(1);
  const Match &zero = synZeroMatch();
  const Match &p0 = tensorMatch(t0);
  const Match &p1 = tensorMatch(t1);
  auto s = merger.buildLattices(e, l0);
  expectNumLatPoints(s, 3);
  expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
  expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
                            loopsToBits({{l0, t0}}));
  expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
                            loopsToBits({{l0, t1}}));
  s = merger.optimizeSet(s);
  expectNumLatPoints(s, 2);
  expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
  expectLatPoint(s, 1, cmpiMatch(zero, p1), loopsToBits({{l0, t1}}));
}