//===- MergerTest.cpp - Tests for the sparsifier's merger -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "llvm/Support/Compiler.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

#include <memory>

using namespace mlir;
using namespace mlir::sparse_tensor;

namespace {

///
/// Defines macros to iterate binary and the combination of binary operations.
///

#define FOREVERY_BINOP(DO)                                                     \
  DO(mulf, TensorExp::Kind::kMulF)                                             \
  DO(mulc, TensorExp::Kind::kMulC)                                             \
  DO(muli, TensorExp::Kind::kMulI)                                             \
  DO(addf, TensorExp::Kind::kAddF)                                             \
  DO(addc, TensorExp::Kind::kAddC)                                             \
  DO(addi, TensorExp::Kind::kAddI)                                             \
  DO(subf, TensorExp::Kind::kSubF)                                             \
  DO(subc, TensorExp::Kind::kSubC)                                             \
  DO(subi, TensorExp::Kind::kSubI)                                             \
  DO(andi, TensorExp::Kind::kAndI)                                             \
  DO(xori, TensorExp::Kind::kXorI)                                             \
  DO(ori, TensorExp::Kind::kOrI)                                               \
  DO(cmpf, TensorExp::Kind::kCmpF)                                             \
  DO(cmpi, TensorExp::Kind::kCmpI)

#define FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, EXTRA)                          \
  TEST(addf, EXTRA)                                                            \
  TEST(addc, EXTRA)                                                            \
  TEST(addi, EXTRA)                                                            \
  TEST(xori, EXTRA)                                                            \
  TEST(ori, EXTRA)

#define FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, EXTRA)                          \
  TEST(mulf, EXTRA)                                                            \
  TEST(mulc, EXTRA)                                                            \
  TEST(muli, EXTRA)                                                            \
  TEST(andi, EXTRA)

#define FOREVERY_COMMON_DISJ_BINOP(TEST)                                       \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, "")

#define FOREVERY_COMMON_CONJ_BINOP(TEST)                                       \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, "")

#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST)                          \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addf)                                 \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addc)                                 \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, addi)                                 \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, xori)                                 \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, ori)

#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST)                          \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, mulf)                                 \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, mulc)                                 \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, muli)                                 \
  FOREVERY_COMMON_CONJ_BINOP_EXTRA(TEST, andi)

#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(TEST)                          \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addf)                                 \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addc)                                 \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, addi)                                 \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, ori)                                  \
  FOREVERY_COMMON_DISJ_BINOP_EXTRA(TEST, xori)

///
/// Helper classes/functions for testing Merger.
///

/// Simple recursive data structure used to match expressions in `Merger`,
/// which uses const references into the short-lived data strucutures.
struct Match {
  struct Children {
    Children(const Match &e0, const Match &e1) : e0(e0), e1(e1) {}
    const Match &e0;
    const Match &e1;
  };

  Match() : kind(TensorExp::Kind::kSynZero) {}
  Match(TensorId tid) : kind(TensorExp::Kind::kTensor), tid(tid) {}
  Match(TensorExp::Kind kind, const Match &e0, const Match &e1)
      : kind(kind), children(e0, e1) {
    assert(kind >= TensorExp::Kind::kMulF);
  }

  TensorExp::Kind kind;
  union {
    TensorId tid;
    Children children;
  };
};

///
/// Readable Match builder functions.
/// These should be preferred over the actual constructors.
///

static Match tensorMatch(TensorId tid) { return Match(tid); }
static Match synZeroMatch() { return Match(); }

#define IMPL_BINOP_PATTERN(OP, KIND)                                           \
  LLVM_ATTRIBUTE_UNUSED static Match OP##Match(const Match &e0,                \
                                               const Match &e1) {              \
    return Match(KIND, e0, e1);                                                \
  }
FOREVERY_BINOP(IMPL_BINOP_PATTERN)
#undef IMPL_BINOP_PATTERN

// Parameterize LevelFormat to test both Dense and Batch LevelFormat.
class MergerTestBase : public ::testing::TestWithParam<LevelFormat> {
protected:
  MergerTestBase(unsigned numTensors, unsigned numLoops)
      : merger(numTensors, numLoops, /*maxRank=*/numLoops) {
    tensors.reserve(numTensors);
    for (unsigned t = 0; t < numTensors; t++)
      tensors.push_back(merger.addTensorExp(tid(t)));
  }

  ///
  /// Expression construction helpers.
  ///

  TensorId tid(unsigned t) const { return merger.makeTensorId(t); }
  LoopId lid(unsigned i) const { return merger.makeLoopId(i); }
  ExprId tensor(unsigned t) const {
    assert(t < tensors.size());
    return tensors[t];
  }

#define IMPL_BINOP_EXPR(OP, KIND)                                              \
  LLVM_ATTRIBUTE_UNUSED ExprId OP##Expr(ExprId e0, ExprId e1) {                \
    return merger.addExp(KIND, e0, e1);                                        \
  }
  FOREVERY_BINOP(IMPL_BINOP_EXPR)
#undef IMPL_BINOP_EXPR

  ///
  /// Comparison helpers.
  ///

  /// Returns true if any lattice point with an expression matching
  /// the given `pattern` and bits matching the given `bits` is present
  /// in the `[lo, lo+n)` slice of the lattice set `s`.  This is useful
  /// for testing partial ordering constraints between lattice points.
  /// We generally know how contiguous groups of lattice points should
  /// be ordered with respect to other groups, but there is no required
  /// ordering within groups.  If `simple` is true, then compare the
  /// `lat.simple` field instead to test the result after optimization.
  bool latPointWithinRange(LatSetId s, unsigned lo, unsigned n,
                           const Match &pattern, const BitVector &bits,
                           bool simple) {
    for (unsigned k = lo, hi = lo + n; k < hi; ++k) {
      if (compareExpression(merger.lat(merger.set(s)[k]).exp, pattern) &&
          compareBits(s, k, bits, simple))
        return true;
    }
    return false;
  }

  /// Wrapper over latPointWithinRange for readability of tests.
  void expectLatPointWithinRange(LatSetId s, unsigned lo, unsigned n,
                                 const Match &pattern, const BitVector &bits,
                                 bool simple = false) {
    EXPECT_TRUE(latPointWithinRange(s, lo, n, pattern, bits, simple));
  }

  /// Wrapper over expectLatPointWithinRange for a single lat point.
  void expectLatPoint(LatSetId s, unsigned lo, const Match &pattern,
                      const BitVector &bits, bool simple = false) {
    EXPECT_TRUE(latPointWithinRange(s, lo, 1, pattern, bits, simple));
  }

  /// Converts a vector of (loop, tensor) pairs to a bitvector with the
  /// corresponding bits set.
  BitVector loopsToBits(const std::vector<std::pair<LoopId, TensorId>> &loops) {
    BitVector testBits = BitVector(merger.getNumTensors(), false);
    for (auto [loop, tensor] : loops)
      testBits.set(merger.makeTensorLoopId(tensor, loop));
    return testBits;
  }

  /// Returns true if the bits of the `k`th point in set `s` matches
  /// the given `bits`.  If `simple` is true, then compares the `lat.simple`
  /// field instead, to test the result after optimization
  bool compareBits(LatSetId s, unsigned k, const BitVector &bits, bool simple) {
    const auto &point = merger.lat(merger.set(s)[k]);
    return (simple ? point.simple : point.bits) == bits;
  }

  /// Check that there are n lattice points in set s.
  void expectNumLatPoints(LatSetId s, unsigned n) {
    EXPECT_THAT(merger.set(s).size(), n);
  }

  /// Compares expressions for equality. Equality is defined recursively as:
  /// - Operations are equal if they have the same kind and children.
  /// - Leaf tensors are equal if they refer to the same tensor.
  bool compareExpression(ExprId e, const Match &pattern) {
    const auto &tensorExp = merger.exp(e);
    if (tensorExp.kind != pattern.kind)
      return false;
    switch (tensorExp.kind) {
    // Leaf.
    case TensorExp::Kind::kTensor:
      return tensorExp.tensor == pattern.tid;
    case TensorExp::Kind::kSynZero:
      // Already checked kind equivalence @L233
      return true;
    case TensorExp::Kind::kInvariant:
      llvm_unreachable("invariant not handled yet");
    case TensorExp::Kind::kLoopVar:
      llvm_unreachable("loop-variables not handled yet");
    // Unary operations.
    case TensorExp::Kind::kAbsF:
    case TensorExp::Kind::kAbsC:
    case TensorExp::Kind::kAbsI:
    case TensorExp::Kind::kCeilF:
    case TensorExp::Kind::kFloorF:
    case TensorExp::Kind::kSqrtF:
    case TensorExp::Kind::kSqrtC:
    case TensorExp::Kind::kExpm1F:
    case TensorExp::Kind::kExpm1C:
    case TensorExp::Kind::kLog1pF:
    case TensorExp::Kind::kLog1pC:
    case TensorExp::Kind::kRelu:
    case TensorExp::Kind::kSinF:
    case TensorExp::Kind::kSinC:
    case TensorExp::Kind::kTanhF:
    case TensorExp::Kind::kTanhC:
    case TensorExp::Kind::kNegF:
    case TensorExp::Kind::kNegC:
    case TensorExp::Kind::kNegI:
    case TensorExp::Kind::kTruncF:
    case TensorExp::Kind::kExtF:
    case TensorExp::Kind::kCastFS:
    case TensorExp::Kind::kCastFU:
    case TensorExp::Kind::kCastSF:
    case TensorExp::Kind::kCastUF:
    case TensorExp::Kind::kCastS:
    case TensorExp::Kind::kCastU:
    case TensorExp::Kind::kCastIdx:
    case TensorExp::Kind::kTruncI:
    case TensorExp::Kind::kCIm:
    case TensorExp::Kind::kCRe:
    case TensorExp::Kind::kBitCast:
    case TensorExp::Kind::kSelect:
    case TensorExp::Kind::kBinaryBranch:
    case TensorExp::Kind::kUnary:
      return compareExpression(tensorExp.children.e0, pattern.children.e0);
    // Binary operations.
    case TensorExp::Kind::kMulF:
    case TensorExp::Kind::kMulC:
    case TensorExp::Kind::kMulI:
    case TensorExp::Kind::kDivF:
    case TensorExp::Kind::kDivC:
    case TensorExp::Kind::kDivS:
    case TensorExp::Kind::kDivU:
    case TensorExp::Kind::kAddF:
    case TensorExp::Kind::kAddC:
    case TensorExp::Kind::kAddI:
    case TensorExp::Kind::kSubF:
    case TensorExp::Kind::kSubC:
    case TensorExp::Kind::kSubI:
    case TensorExp::Kind::kAndI:
    case TensorExp::Kind::kOrI:
    case TensorExp::Kind::kXorI:
    case TensorExp::Kind::kCmpF:
    case TensorExp::Kind::kCmpI:
    case TensorExp::Kind::kShrS:
    case TensorExp::Kind::kShrU:
    case TensorExp::Kind::kShlI:
    case TensorExp::Kind::kBinary:
    case TensorExp::Kind::kReduce:
      return compareExpression(tensorExp.children.e0, pattern.children.e0) &&
             compareExpression(tensorExp.children.e1, pattern.children.e1);
    case TensorExp::Kind::kDenseOp: {
      bool eq = compareExpression(tensorExp.children.e0, pattern.children.e0);
      if (eq && tensorExp.children.e1 != sparse_tensor::detail::kInvalidId)
        return compareExpression(tensorExp.children.e1, pattern.children.e1);
      return eq;
    }
    }
    llvm_unreachable("unexpected kind");
  }

  // This field is public for convenience.
  Merger merger;

private:
  // This field is private to prevent mutation after the ctor.
  SmallVector<ExprId> tensors;
};

///
/// Tests with all sparse inputs.
///

/// Three tensors (two inputs, one output); and a single loop.
class MergerTest3T1L : public MergerTestBase {
protected:
  MergerTest3T1L() : MergerTestBase(3, 1) {
    EXPECT_TRUE(merger.getOutTensorID() == tid(2));
    // Tensor 0: sparse input vector.
    merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
    // Tensor 1: sparse input vector.
    merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
    // Tensor 2: dense output vector.
    merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
  }
};

INSTANTIATE_TEST_SUITE_P(Test3T1L, MergerTest3T1L,
                         ::testing::Values(LevelFormat::Dense,
                                           LevelFormat::Batch));

/// Four tensors (three inputs, one output); and a single loop.
class MergerTest4T1L : public MergerTestBase {
protected:
  MergerTest4T1L() : MergerTestBase(4, 1) {
    EXPECT_TRUE(merger.getOutTensorID() == tid(3));
    // Tensor 0: sparse input vector.
    merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
    // Tensor 1: sparse input vector.
    merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Compressed);
    // Tensor 2: sparse input vector
    merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
    // Tensor 3: dense output vector
    merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
  }
};

INSTANTIATE_TEST_SUITE_P(Test4T1L, MergerTest4T1L,
                         ::testing::Values(LevelFormat::Dense,
                                           LevelFormat::Batch));

///
/// Tests with both sparse and dense input.
///

/// Three tensors (two inputs, one output); and a single loop.
class MergerTest3T1LD : public MergerTestBase {
protected:
  MergerTest3T1LD() : MergerTestBase(3, 1) {
    EXPECT_TRUE(merger.getOutTensorID() == tid(2));
    // Tensor 0: sparse input vector.
    merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Compressed);
    // Tensor 1: dense input vector.
    merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
    // Tensor 2: dense output vector.
    merger.setLevelAndType(tid(2), lid(0), 0, GetParam());
  }
};

INSTANTIATE_TEST_SUITE_P(Test3T1LD, MergerTest3T1LD,
                         ::testing::Values(LevelFormat::Dense,
                                           LevelFormat::Batch));

///
/// Tests with both undef and dense input.
///

/// Three tensors (three inputs, one output); and a single loop.
class MergerTest4T1LU : public MergerTestBase {
protected:
  MergerTest4T1LU() : MergerTestBase(4, 1) {
    EXPECT_TRUE(merger.getOutTensorID() == tid(3));
    // Tensor 0: undef input vector.
    merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
    // Tensor 1: dense input vector.
    merger.setLevelAndType(tid(1), lid(0), 0, GetParam());
    // Tensor 2: undef input vector.
    merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Undef);
    // Tensor 3: dense output vector.
    merger.setLevelAndType(tid(3), lid(0), 0, GetParam());
  }
};

INSTANTIATE_TEST_SUITE_P(Test4T1LU, MergerTest4T1LU,
                         ::testing::Values(LevelFormat::Dense,
                                           LevelFormat::Batch));

///
/// Tests with operation on sparse output.
///

/// Three tensors (two inputs, one output, one synthetic); and a single loop.
class MergerTest3T1LSo : public MergerTestBase {
protected:
  MergerTest3T1LSo() : MergerTestBase(3, 1) {
    EXPECT_TRUE(merger.getOutTensorID() == tid(2));
    EXPECT_TRUE(merger.getSynTensorID() == tid(3));
    merger.setHasSparseOut(true);
    // Tensor 0: undef input vector.
    merger.setLevelAndType(tid(0), lid(0), 0, LevelFormat::Undef);
    // Tensor 1: undef input vector.
    merger.setLevelAndType(tid(1), lid(0), 0, LevelFormat::Undef);
    // Tensor 2: sparse output vector.
    merger.setLevelAndType(tid(2), lid(0), 0, LevelFormat::Compressed);
  }
};

// This testsuite does not use any dense-like format, just one of {Dense, Batch}
// is enough.
INSTANTIATE_TEST_SUITE_P(Test3T1LSo, MergerTest3T1LSo,
                         ::testing::Values(LevelFormat::Dense));

} // namespace

/// Vector multiplication (conjunction) of 3 vectors, i.e.;
///   a(i) = b(i) * c(i) * d(i)
/// which should form the single lattice point
/// {
///   lat( i_00_U i_01_D i_02_U / (tensor_0 * tensor_1 * tensor2) )
/// }
/// after optimization, the dense dimesion should be kept, despite it appears
/// in the middle
/// {
///   lat( i_01_D / (tensor_0 * tensor_1 * tensor2) )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_UNDEF(CONJ1, CONJ2)                         \
  TEST_P(MergerTest4T1LU, vector_##CONJ1##_##CONJ2) {                          \
    const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
    const auto e = CONJ2##Expr(em, tensor(2));                                 \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const auto t2 = tid(2);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    const Match &p2 = tensorMatch(t2);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t1}}), true);                             \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_UNDEF)
#undef IMPL_MERGER_TEST_CONJ_CONJ_UNDEF

/// Vector multiplication (conjunction) of 2 vectors, i.e.;
///   o(i) = b(i) * c(i) * o(i)
/// which should form the single lattice point (note how a synthetic tensor
/// i_03_U is created for the sparse output)
/// {
///   lat( i_00_U i_01_U i_03_U / (tensor_0 * tensor_1 * output_tensor_2) )
/// }
/// after optimization, the synthetic tensor should be preserved.
/// {
///   lat( i_03_U / (tensor_0 * tensor_1 * output_tensor2) )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT(CONJ1, CONJ2)                    \
  TEST_P(MergerTest3T1LSo, vector_##CONJ1##_##CONJ2) {                         \
    const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
    const auto e = CONJ2##Expr(em, tensor(2));                                 \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const auto t2 = tid(2);                                                    \
    const auto t3 = tid(3);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    const Match &p2 = tensorMatch(t2);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t3}}));               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t3}}), true);                             \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT)
#undef IMPL_MERGER_TEST_CONJ_CONJ_SPARSE_OUT

/// Vector addition (disjunction) of 2 vectors. i.e.;
///   a(i) = b(i) + c(i)
/// which should form the 3 lattice points
/// {
///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
///   lat( i_00 / tensor_0 )
///   lat( i_01 / tensor_1 )
/// }
/// and after optimization, the lattice points do not change (as there is no
/// duplicated point and all input vectors are sparse vector).
/// {
///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
///   lat( i_00 / tensor_0 )
///   lat( i_01 / tensor_1 )
/// }
#define IMPL_MERGER_TEST_DISJ(OP, UNUSED)                                      \
  TEST_P(MergerTest3T1L, vector_##OP) {                                        \
    const auto e = OP##Expr(tensor(0), tensor(1));                             \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
                                                                               \
    expectNumLatPoints(s, 3);                                                  \
    expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
                   loopsToBits({{l0, t0}, {l0, t1}}));                         \
    expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}));           \
    expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}));           \
                                                                               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 3);                                                  \
    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
                   true);                                                      \
    expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}), true);     \
    expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}), true);     \
  }
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
#undef IMPL_MERGER_TEST_DISJ

/// Vector multiplication (conjunction) of 2 vectors, i.e.;
///   a(i) = b(i) * c(i)
/// which should form the single lattice point
/// {
///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
/// }
#define IMPL_MERGER_TEST_CONJ(OP, UNUSED)                                      \
  TEST_P(MergerTest3T1L, vector_##OP) {                                        \
    const auto e = OP##Expr(tensor(0), tensor(1));                             \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
                                                                               \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
                   loopsToBits({{l0, t0}, {l0, t1}}));                         \
                                                                               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
                   true);                                                      \
  }
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
#undef IMPL_MERGER_TEST_CONJ

/// Vector multiplication (conjunction) then addition (disjunction), i.e.;
///   a(i) = b(i) * c(i) + d(i);
/// which should form
/// {
///    lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 )
///    lat( i_00 i_01 / tensor_0 * tensor_1
///    lat( i_02 / tensor_2 )
/// }
#define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ)                                 \
  TEST_P(MergerTest4T1L, vector_##CONJ##_##DISJ) {                             \
    const auto em = CONJ##Expr(tensor(0), tensor(1));                          \
    const auto e = DISJ##Expr(em, tensor(2));                                  \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const auto t2 = tid(2);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    const Match &p2 = tensorMatch(t2);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
                                                                               \
    expectNumLatPoints(s, 3);                                                  \
    expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2),                 \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
    expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1),                    \
                              loopsToBits({{l0, t0}, {l0, t1}}));              \
    expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}}));           \
                                                                               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 3);                                                  \
    expectLatPoint(s, 0, DISJ##Match(CONJ##Match(p0, p1), p2),                 \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
    expectLatPointWithinRange(s, 1, 2, CONJ##Match(p0, p1),                    \
                              loopsToBits({{l0, t0}, {l0, t1}}));              \
    expectLatPointWithinRange(s, 1, 2, p2, loopsToBits({{l0, t2}}));           \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
#undef IMPL_MERGER_TEST_CONJ_DISJ

/// Vector addition (disjunction) then addition (disjunction), i.e.;
///   a(i) = b(i) + c(i) + d(i)
/// which should form
/// {
///   lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 )
///   lat( i_02 i_01 / tensor_2 + tensor_1 )
///   lat( i_02 i_00 / tensor_2 + tensor_0 )
///   lat( i_01 i_00 / tensor_1 + tensor_0 )
///   lat( i_02 / tensor_2 )
///   lat( i_01 / tensor_1 )
///   lat( i_00 / tensor_0 )
/// }
#define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2)                               \
  TEST_P(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) {                           \
    const auto em = DISJ1##Expr(tensor(0), tensor(1));                         \
    const auto e = DISJ2##Expr(em, tensor(2));                                 \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const auto t2 = tid(2);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    const Match &p2 = tensorMatch(t2);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
                                                                               \
    expectNumLatPoints(s, 7);                                                  \
    expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2),                   \
                              loopsToBits({{l0, t1}, {l0, t2}}));              \
    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2),                   \
                              loopsToBits({{l0, t0}, {l0, t2}}));              \
    expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1),                   \
                              loopsToBits({{l0, t0}, {l0, t1}}));              \
    expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}}));           \
    expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}}));           \
    expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}}));           \
                                                                               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 7);                                                  \
    expectLatPoint(s, 0, DISJ2##Match(DISJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p1, p2),                   \
                              loopsToBits({{l0, t1}, {l0, t2}}));              \
    expectLatPointWithinRange(s, 1, 6, DISJ2##Match(p0, p2),                   \
                              loopsToBits({{l0, t0}, {l0, t2}}));              \
    expectLatPointWithinRange(s, 1, 6, DISJ1##Match(p0, p1),                   \
                              loopsToBits({{l0, t0}, {l0, t1}}));              \
    expectLatPointWithinRange(s, 1, 6, p2, loopsToBits({{l0, t2}}));           \
    expectLatPointWithinRange(s, 1, 6, p1, loopsToBits({{l0, t1}}));           \
    expectLatPointWithinRange(s, 1, 6, p0, loopsToBits({{l0, t0}}));           \
  }
FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
#undef IMPL_MERGER_TEST_DISJ_DISJ

/// Vector multiplication (conjunction) then multiplication (conjunction), i.e.;
///   a(i) = b(i) * c(i) * d(i);
/// which should form
/// {
///    lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
/// }
#define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2)                               \
  TEST_P(MergerTest4T1L, vector_##CONJ1##_##CONJ2) {                           \
    const auto em = CONJ1##Expr(tensor(0), tensor(1));                         \
    const auto e = CONJ2##Expr(em, tensor(2));                                 \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const auto t2 = tid(2);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    const Match &p2 = tensorMatch(t2);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, CONJ2##Match(CONJ1##Match(p0, p1), p2),               \
                   loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true);         \
  }
FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
#undef IMPL_MERGER_TEST_CONJ_CONJ

/// Vector addition (disjunction) of 2 vectors, i.e.;
///   a(i) = b(i) + c(i)
/// which should form the 3 lattice points
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) )
///   lat( i_00 / sparse_tensor_0 )
///   lat( i_01 / dense_tensor_1 )
/// }
/// which should be optimized to
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton)
///   lat( i_01 / dense_tensor_0 ) (no sparse dimension)
/// }
///
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
/// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
#define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP, UNUSED)                            \
  TEST_P(MergerTest3T1LD, vector_opted_##OP) {                                 \
    const auto e = OP##Expr(tensor(0), tensor(1));                             \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
                                                                               \
    expectNumLatPoints(s, 3);                                                  \
    expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
                   loopsToBits({{l0, t0}, {l0, t1}}));                         \
    expectLatPointWithinRange(s, 1, 2, p0, loopsToBits({{l0, t0}}));           \
    expectLatPointWithinRange(s, 1, 2, p1, loopsToBits({{l0, t1}}));           \
                                                                               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 2);                                                  \
    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}), \
                   true);                                                      \
    expectLatPoint(s, 1, p1, loopsToBits({{l0, t1}}), true);                   \
  }
FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ

/// Vector multiplication (conjunction) of 2 vectors, i.e.:
///   a(i) = b(i) * c(i)
/// which should form the single lattice point
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) )
/// }
/// it should be optimized to
/// {
///   lat( i_00 / (sparse_tensor_0 * dense_tensor_1) )
/// }
/// since i_01 is a dense dimension.
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP, UNUSED)                            \
  TEST_P(MergerTest3T1LD, vector_opted_##OP) {                                 \
    const auto e = OP##Expr(tensor(0), tensor(1));                             \
    const auto l0 = lid(0);                                                    \
    const auto t0 = tid(0);                                                    \
    const auto t1 = tid(1);                                                    \
    const Match &p0 = tensorMatch(t0);                                         \
    const Match &p1 = tensorMatch(t1);                                         \
    auto s = merger.buildLattices(e, l0);                                      \
                                                                               \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, OP##Match(p0, p1),                                    \
                   loopsToBits({{l0, t0}, {l0, t1}}));                         \
                                                                               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, 0, OP##Match(p0, p1), loopsToBits({{l0, t0}}), true);    \
  }
FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ

/// Vector element-wise comparison (disjunction) of 2 vectors. i.e.;
///   a(i) = b(i) + c(i)
/// which should form the 3 lattice points
/// {
///   lat( i_00 i_01 / (tensor_0 cmp tensor_1) )
///   lat( i_00 / tensor_0 cmp 0 )
///   lat( i_01 / 0 cmp tensor_1 )
/// }
/// and after optimization, the lattice points do not change (as there is no
/// duplicated point and all input vectors are sparse vector).
/// {
///   lat( i_00 i_01 / (tensor_0 cmp tensor_1) )
///   lat( i_00 / tensor_0 cmp 0 )
///   lat( i_01 / 0 cmp tensor_1 )
/// }
TEST_P(MergerTest3T1L, vector_cmp) {
  const auto e = cmpiExpr(tensor(0), tensor(1));
  const auto l0 = lid(0);
  const auto t0 = tid(0);
  const auto t1 = tid(1);
  const Match &zero = synZeroMatch();
  const Match &p0 = tensorMatch(t0);
  const Match &p1 = tensorMatch(t1);
  auto s = merger.buildLattices(e, l0);
  expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
  expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
                            loopsToBits({{l0, t0}}));
  expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
                            loopsToBits({{l0, t1}}));
  s = merger.optimizeSet(s);
  expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
  expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
                            loopsToBits({{l0, t0}}));
  expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
                            loopsToBits({{l0, t1}}));
}

/// Vector element-wise comparsion (disjunction) of 2 vectors, i.e.;
///   a(i) = b(i) cmp c(i)
/// which should form the 3 lattice points
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) )
///   lat( i_00 / sparse_tensor_0 cmp 0)
///   lat( i_01 / 0 cmp dense_tensor_1 )
/// }
/// which should be optimized to
/// {
///   lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ) (not singleton)
///   lat( i_01 / 0 cmp dense_tensor_0 ) ()
/// }
///
/// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
/// with lat( i_00 i_01 / (sparse_tensor_0 cmp dense_tensor_1) ).
TEST_P(MergerTest3T1LD, vector_cmp) {
  const auto e = cmpiExpr(tensor(0), tensor(1));
  const auto l0 = lid(0);
  const auto t0 = tid(0);
  const auto t1 = tid(1);
  const Match &zero = synZeroMatch();
  const Match &p0 = tensorMatch(t0);
  const Match &p1 = tensorMatch(t1);
  auto s = merger.buildLattices(e, l0);
  expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
  expectLatPointWithinRange(s, 1, 2, cmpiMatch(p0, zero),
                            loopsToBits({{l0, t0}}));
  expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
                            loopsToBits({{l0, t1}}));
  s = merger.optimizeSet(s);
  expectLatPoint(s, 0, cmpiMatch(p0, p1), loopsToBits({{l0, t0}, {l0, t1}}));
  expectLatPointWithinRange(s, 1, 2, cmpiMatch(zero, p1),
                            loopsToBits({{l0, t1}}));
}