140843347SGus Smith #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
2*66ae1d60SPeiming Liu #include "llvm/Support/Compiler.h"
340843347SGus Smith #include "gmock/gmock.h"
440843347SGus Smith #include "gtest/gtest.h"
540843347SGus Smith #include <memory>
640843347SGus Smith 
76842ec42SRiver Riddle using namespace mlir;
840843347SGus Smith using namespace mlir::sparse_tensor;
940843347SGus Smith 
// Silence 'warning C4002: too many arguments for function-like macro
//                          invocation'
// as MSVC handles ##__VA_ARGS__ differently from gcc/clang.
13*66ae1d60SPeiming Liu 
14*66ae1d60SPeiming Liu #if defined(_MSC_VER) && !defined(__clang__)
15*66ae1d60SPeiming Liu #pragma warning(push)
16*66ae1d60SPeiming Liu #pragma warning(disable : 4002)
17*66ae1d60SPeiming Liu #endif
18*66ae1d60SPeiming Liu 
1940843347SGus Smith namespace {
2040843347SGus Smith 
21*66ae1d60SPeiming Liu ///
22*66ae1d60SPeiming Liu /// Defines macros to iterate binary and the combination of binary operations.
23*66ae1d60SPeiming Liu ///
24*66ae1d60SPeiming Liu 
/// X-macro over every binary operation covered by these tests. Invokes
/// DO(op, kind) once per operation, where `op` is the lower-case mnemonic used
/// to build helper names and `kind` is the corresponding Kind enumerator.
/// Entries are grouped by operator family (mul, add, sub, bitwise).
#define FOREVERY_BINOP(DO)                                                     \
  DO(mulf, Kind::kMulF) DO(mulc, Kind::kMulC) DO(muli, Kind::kMulI)            \
  DO(addf, Kind::kAddF) DO(addc, Kind::kAddC) DO(addi, Kind::kAddI)            \
  DO(subf, Kind::kSubF) DO(subc, Kind::kSubC) DO(subi, Kind::kSubI)            \
  DO(andi, Kind::kAndI) DO(xori, Kind::kXorI) DO(ori, Kind::kOrI)
38*66ae1d60SPeiming Liu 
// Invokes TEST(op, <extra args>) for every common disjunctive binary
// operation.
// TODO: Disjunctive binary operations that need special handling are not
// included; e.g., division is not tested (for now) as it needs a constant
// non-zero divisor.
// ##__VA_ARGS__ handles the case where __VA_ARGS__ is empty.
#define FOREVERY_COMMON_DISJ_BINOP(TEST, ...)                                  \
  TEST(addf, ##__VA_ARGS__)                                                    \
  TEST(addc, ##__VA_ARGS__)                                                    \
  TEST(addi, ##__VA_ARGS__)                                                    \
  TEST(xori, ##__VA_ARGS__)                                                    \
  TEST(ori, ##__VA_ARGS__)
49*66ae1d60SPeiming Liu 
// Invokes TEST(op, <extra args>) for every common conjunctive binary
// operation.
// TODO: Conjunctive binary operations that need special handling are not
// included; e.g., subtraction yields a different pattern as it is mapped to a
// negate operation.
#define FOREVERY_COMMON_CONJ_BINOP(TEST, ...)                                  \
  TEST(mulf, ##__VA_ARGS__)                                                    \
  TEST(mulc, ##__VA_ARGS__)                                                    \
  TEST(muli, ##__VA_ARGS__)                                                    \
  TEST(andi, ##__VA_ARGS__)
58*66ae1d60SPeiming Liu 
// Invokes TEST(conj, disj) for every pairing of a common conjunctive operation
// (first argument) with a common disjunctive operation (second argument).
#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST)                          \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addf)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addc)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addi)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, xori)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, ori)
65*66ae1d60SPeiming Liu 
// Invokes TEST(conj1, conj2) for every ordered pair of common conjunctive
// operations.
#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST)                          \
  FOREVERY_COMMON_CONJ_BINOP(TEST, mulf)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, mulc)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, muli)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, andi)
71*66ae1d60SPeiming Liu 
// Invokes TEST(disj1, disj2) for every ordered pair of common disjunctive
// operations.
#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(TEST)                          \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addf)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addc)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addi)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, ori)                                        \
  FOREVERY_COMMON_DISJ_BINOP(TEST, xori)
78*66ae1d60SPeiming Liu 
79*66ae1d60SPeiming Liu ///
80*66ae1d60SPeiming Liu /// Helper classes/functions for testing Merger.
81*66ae1d60SPeiming Liu ///
82*66ae1d60SPeiming Liu 
8340843347SGus Smith /// Simple recursive data structure used to match expressions in Mergers.
8440843347SGus Smith struct Pattern {
8540843347SGus Smith   Kind kind;
8640843347SGus Smith 
8740843347SGus Smith   /// Expressions representing tensors simply have a tensor number.
8840843347SGus Smith   unsigned tensorNum;
8940843347SGus Smith 
9040843347SGus Smith   /// Tensor operations point to their children.
9140843347SGus Smith   std::shared_ptr<Pattern> e0;
9240843347SGus Smith   std::shared_ptr<Pattern> e1;
9340843347SGus Smith 
9440843347SGus Smith   /// Constructors.
9540843347SGus Smith   /// Rather than using these, please use the readable helper constructor
9640843347SGus Smith   /// functions below to make tests more readable.
Pattern__anonb1ca32140111::Pattern9740843347SGus Smith   Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
Pattern__anonb1ca32140111::Pattern981fc096afSMehdi Amini   Pattern(Kind kind, const std::shared_ptr<Pattern> &e0,
991fc096afSMehdi Amini           const std::shared_ptr<Pattern> &e1)
10040843347SGus Smith       : kind(kind), e0(e0), e1(e1) {
10140843347SGus Smith     assert(kind >= Kind::kMulF);
10240843347SGus Smith     assert(e0 && e1);
10340843347SGus Smith   }
10440843347SGus Smith };
10540843347SGus Smith 
10640843347SGus Smith ///
10740843347SGus Smith /// Readable Pattern builder functions.
10840843347SGus Smith /// These should be preferred over the actual constructors.
10940843347SGus Smith ///
11040843347SGus Smith 
tensorPattern(unsigned tensorNum)11140843347SGus Smith static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
11240843347SGus Smith   return std::make_shared<Pattern>(tensorNum);
11340843347SGus Smith }
11440843347SGus Smith 
115*66ae1d60SPeiming Liu #define IMPL_BINOP_PATTERN(OP, KIND)                                           \
116*66ae1d60SPeiming Liu   LLVM_ATTRIBUTE_UNUSED static std::shared_ptr<Pattern> OP##Pattern(           \
117*66ae1d60SPeiming Liu       const std::shared_ptr<Pattern> &e0,                                      \
118*66ae1d60SPeiming Liu       const std::shared_ptr<Pattern> &e1) {                                    \
119*66ae1d60SPeiming Liu     return std::make_shared<Pattern>(KIND, e0, e1);                            \
12040843347SGus Smith   }
12140843347SGus Smith 
122*66ae1d60SPeiming Liu FOREVERY_BINOP(IMPL_BINOP_PATTERN)
123*66ae1d60SPeiming Liu 
124*66ae1d60SPeiming Liu #undef IMPL_BINOP_PATTERN
12540843347SGus Smith 
12640843347SGus Smith class MergerTestBase : public ::testing::Test {
12740843347SGus Smith protected:
MergerTestBase(unsigned numTensors,unsigned numLoops)12840843347SGus Smith   MergerTestBase(unsigned numTensors, unsigned numLoops)
12940843347SGus Smith       : numTensors(numTensors), numLoops(numLoops),
13040843347SGus Smith         merger(numTensors, numLoops) {}
13140843347SGus Smith 
13240843347SGus Smith   ///
13340843347SGus Smith   /// Expression construction helpers.
13440843347SGus Smith   ///
13540843347SGus Smith 
tensor(unsigned tensor)13640843347SGus Smith   unsigned tensor(unsigned tensor) {
13740843347SGus Smith     return merger.addExp(Kind::kTensor, tensor);
13840843347SGus Smith   }
13940843347SGus Smith 
140*66ae1d60SPeiming Liu #define IMPL_BINOP_EXPR(OP, KIND)                                              \
141*66ae1d60SPeiming Liu   LLVM_ATTRIBUTE_UNUSED unsigned OP##Expr(unsigned e0, unsigned e1) {          \
142*66ae1d60SPeiming Liu     return merger.addExp(KIND, e0, e1);                                        \
14340843347SGus Smith   }
14440843347SGus Smith 
FOREVERY_BINOP(IMPL_BINOP_EXPR)145*66ae1d60SPeiming Liu   FOREVERY_BINOP(IMPL_BINOP_EXPR)
146*66ae1d60SPeiming Liu 
147*66ae1d60SPeiming Liu #undef IMPL_BINOP_EXPR
14840843347SGus Smith 
14940843347SGus Smith   ///
15040843347SGus Smith   /// Comparison helpers.
15140843347SGus Smith   ///
15240843347SGus Smith 
15340843347SGus Smith   /// For readability of tests.
15440843347SGus Smith   unsigned lat(unsigned lat) { return lat; }
15540843347SGus Smith 
15640843347SGus Smith   /// Returns true if a lattice point with an expression matching the given
15740843347SGus Smith   /// pattern and bits matching the given bits is present in lattice points
15840843347SGus Smith   /// [p, p+n) of lattice set s. This is useful for testing partial ordering
15940843347SGus Smith   /// constraints between lattice points. We generally know how contiguous
16040843347SGus Smith   /// groups of lattice points should be ordered with respect to other groups,
16140843347SGus Smith   /// but there is no required ordering within groups.
162*66ae1d60SPeiming Liu   /// If simple is true, then compare the lat.simple field instead to test the
163*66ae1d60SPeiming Liu   /// result after optimization
latPointWithinRange(unsigned s,unsigned p,unsigned n,const std::shared_ptr<Pattern> & pattern,const BitVector & bits,bool simple)16440843347SGus Smith   bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
1651fc096afSMehdi Amini                            const std::shared_ptr<Pattern> &pattern,
166*66ae1d60SPeiming Liu                            const BitVector &bits, bool simple) {
16740843347SGus Smith     for (unsigned i = p; i < p + n; ++i) {
16840843347SGus Smith       if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
169*66ae1d60SPeiming Liu           compareBits(s, i, bits, simple))
17040843347SGus Smith         return true;
17140843347SGus Smith     }
17240843347SGus Smith     return false;
17340843347SGus Smith   }
17440843347SGus Smith 
17540843347SGus Smith   /// Wrapper over latPointWithinRange for readability of tests.
expectLatPointWithinRange(unsigned s,unsigned p,unsigned n,const std::shared_ptr<Pattern> & pattern,const BitVector & bits,bool simple=false)17640843347SGus Smith   void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
1774f415216SMehdi Amini                                  const std::shared_ptr<Pattern> &pattern,
178*66ae1d60SPeiming Liu                                  const BitVector &bits, bool simple = false) {
179*66ae1d60SPeiming Liu     EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits, simple));
18040843347SGus Smith   }
18140843347SGus Smith 
18240843347SGus Smith   /// Wrapper over expectLatPointWithinRange for a single lat point.
expectLatPoint(unsigned s,unsigned p,const std::shared_ptr<Pattern> & pattern,const BitVector & bits,bool simple=false)1834f415216SMehdi Amini   void expectLatPoint(unsigned s, unsigned p,
1844f415216SMehdi Amini                       const std::shared_ptr<Pattern> &pattern,
185*66ae1d60SPeiming Liu                       const BitVector &bits, bool simple = false) {
186*66ae1d60SPeiming Liu     EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits, simple));
18740843347SGus Smith   }
18840843347SGus Smith 
18940843347SGus Smith   /// Converts a vector of (loop, tensor) pairs to a bitvector with the
19040843347SGus Smith   /// corresponding bits set.
191d10d49dcSRiver Riddle   BitVector
loopsToBits(const std::vector<std::pair<unsigned,unsigned>> & loops)1921fc096afSMehdi Amini   loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) {
193d10d49dcSRiver Riddle     BitVector testBits = BitVector(numTensors + 1, false);
19440843347SGus Smith     for (auto l : loops) {
19540843347SGus Smith       auto loop = std::get<0>(l);
19640843347SGus Smith       auto tensor = std::get<1>(l);
19740843347SGus Smith       testBits.set(numTensors * loop + tensor);
19840843347SGus Smith     }
19940843347SGus Smith     return testBits;
20040843347SGus Smith   }
20140843347SGus Smith 
20240843347SGus Smith   /// Returns true if the bits of lattice point p in set s match the given bits.
203*66ae1d60SPeiming Liu   /// If simple is true, then compare the lat.simple field instead to test the
204*66ae1d60SPeiming Liu   /// result after optimization
compareBits(unsigned s,unsigned p,const BitVector & bits,bool simple)205*66ae1d60SPeiming Liu   bool compareBits(unsigned s, unsigned p, const BitVector &bits, bool simple) {
206*66ae1d60SPeiming Liu     if (simple)
207*66ae1d60SPeiming Liu       return merger.lat(merger.set(s)[p]).simple == bits;
20840843347SGus Smith     return merger.lat(merger.set(s)[p]).bits == bits;
20940843347SGus Smith   }
21040843347SGus Smith 
21140843347SGus Smith   /// Check that there are n lattice points in set s.
expectNumLatPoints(unsigned s,unsigned n)21240843347SGus Smith   void expectNumLatPoints(unsigned s, unsigned n) {
21340843347SGus Smith     EXPECT_THAT(merger.set(s).size(), n);
21440843347SGus Smith   }
21540843347SGus Smith 
21640843347SGus Smith   /// Compares expressions for equality. Equality is defined recursively as:
21706aa6ec8SAart Bik   /// - Operations are equal if they have the same kind and children.
21806aa6ec8SAart Bik   /// - Leaf tensors are equal if they refer to the same tensor.
compareExpression(unsigned e,const std::shared_ptr<Pattern> & pattern)2191fc096afSMehdi Amini   bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
22040843347SGus Smith     auto tensorExp = merger.exp(e);
22140843347SGus Smith     if (tensorExp.kind != pattern->kind)
22240843347SGus Smith       return false;
22340843347SGus Smith     switch (tensorExp.kind) {
22406aa6ec8SAart Bik     // Leaf.
22506aa6ec8SAart Bik     case kTensor:
22640843347SGus Smith       return tensorExp.tensor == pattern->tensorNum;
22706aa6ec8SAart Bik     case kInvariant:
22806aa6ec8SAart Bik     case kIndex:
22906aa6ec8SAart Bik       llvm_unreachable("invariant not handled yet");
23006aa6ec8SAart Bik     // Unary operations.
23106aa6ec8SAart Bik     case kAbsF:
23206aa6ec8SAart Bik     case kAbsC:
23306aa6ec8SAart Bik     case kCeilF:
23406aa6ec8SAart Bik     case kFloorF:
23506aa6ec8SAart Bik     case kSqrtF:
23606aa6ec8SAart Bik     case kSqrtC:
23706aa6ec8SAart Bik     case kExpm1F:
23806aa6ec8SAart Bik     case kExpm1C:
23906aa6ec8SAart Bik     case kLog1pF:
24006aa6ec8SAart Bik     case kLog1pC:
24106aa6ec8SAart Bik     case kSinF:
24206aa6ec8SAart Bik     case kSinC:
24306aa6ec8SAart Bik     case kTanhF:
24406aa6ec8SAart Bik     case kTanhC:
24506aa6ec8SAart Bik     case kNegF:
24606aa6ec8SAart Bik     case kNegC:
24706aa6ec8SAart Bik     case kNegI:
24806aa6ec8SAart Bik     case kTruncF:
24906aa6ec8SAart Bik     case kExtF:
25006aa6ec8SAart Bik     case kCastFS:
25106aa6ec8SAart Bik     case kCastFU:
25206aa6ec8SAart Bik     case kCastSF:
25306aa6ec8SAart Bik     case kCastUF:
25406aa6ec8SAart Bik     case kCastS:
25506aa6ec8SAart Bik     case kCastU:
25606aa6ec8SAart Bik     case kCastIdx:
25706aa6ec8SAart Bik     case kTruncI:
25806aa6ec8SAart Bik     case kCIm:
25906aa6ec8SAart Bik     case kCRe:
26006aa6ec8SAart Bik     case kBitCast:
26106aa6ec8SAart Bik     case kBinaryBranch:
26206aa6ec8SAart Bik     case kUnary:
26306aa6ec8SAart Bik     case kShlI:
26406aa6ec8SAart Bik     case kBinary:
265123e8dfcSAart Bik       return compareExpression(tensorExp.children.e0, pattern->e0);
26606aa6ec8SAart Bik     // Binary operations.
26706aa6ec8SAart Bik     case kMulF:
26806aa6ec8SAart Bik     case kMulC:
26906aa6ec8SAart Bik     case kMulI:
27006aa6ec8SAart Bik     case kDivF:
27106aa6ec8SAart Bik     case kDivC:
27206aa6ec8SAart Bik     case kDivS:
27306aa6ec8SAart Bik     case kDivU:
27406aa6ec8SAart Bik     case kAddF:
27506aa6ec8SAart Bik     case kAddC:
27606aa6ec8SAart Bik     case kAddI:
27706aa6ec8SAart Bik     case kSubF:
27806aa6ec8SAart Bik     case kSubC:
27906aa6ec8SAart Bik     case kSubI:
28006aa6ec8SAart Bik     case kAndI:
28106aa6ec8SAart Bik     case kOrI:
28206aa6ec8SAart Bik     case kXorI:
28306aa6ec8SAart Bik     case kShrS:
28406aa6ec8SAart Bik     case kShrU:
28540843347SGus Smith       return compareExpression(tensorExp.children.e0, pattern->e0) &&
28640843347SGus Smith              compareExpression(tensorExp.children.e1, pattern->e1);
28740843347SGus Smith     }
28806aa6ec8SAart Bik     llvm_unreachable("unexpected kind");
28940843347SGus Smith   }
29040843347SGus Smith 
29140843347SGus Smith   unsigned numTensors;
29240843347SGus Smith   unsigned numLoops;
29340843347SGus Smith   Merger merger;
29440843347SGus Smith };
29540843347SGus Smith 
296*66ae1d60SPeiming Liu ///
297*66ae1d60SPeiming Liu /// Tests with all sparse inputs.
298*66ae1d60SPeiming Liu ///
299*66ae1d60SPeiming Liu 
30040843347SGus Smith class MergerTest3T1L : public MergerTestBase {
30140843347SGus Smith protected:
30240843347SGus Smith   // Our three tensors (two inputs, one output).
30340843347SGus Smith   const unsigned t0 = 0, t1 = 1, t2 = 2;
30440843347SGus Smith 
30540843347SGus Smith   // Our single loop.
30640843347SGus Smith   const unsigned l0 = 0;
30740843347SGus Smith 
MergerTest3T1L()30840843347SGus Smith   MergerTest3T1L() : MergerTestBase(3, 1) {
30940843347SGus Smith     // Tensor 0: sparse input vector.
31040843347SGus Smith     merger.addExp(Kind::kTensor, t0, -1u);
31140843347SGus Smith     merger.setDim(t0, l0, Dim::kSparse);
31240843347SGus Smith 
31340843347SGus Smith     // Tensor 1: sparse input vector.
31440843347SGus Smith     merger.addExp(Kind::kTensor, t1, -1u);
31540843347SGus Smith     merger.setDim(t1, l0, Dim::kSparse);
31640843347SGus Smith 
31740843347SGus Smith     // Tensor 2: dense output vector.
31840843347SGus Smith     merger.addExp(Kind::kTensor, t2, -1u);
31940843347SGus Smith     merger.setDim(t2, l0, Dim::kDense);
32040843347SGus Smith   }
32140843347SGus Smith };
32240843347SGus Smith 
323*66ae1d60SPeiming Liu class MergerTest4T1L : public MergerTestBase {
324*66ae1d60SPeiming Liu protected:
325*66ae1d60SPeiming Liu   // Our four tensors (three inputs, one output).
326*66ae1d60SPeiming Liu   const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
327*66ae1d60SPeiming Liu 
328*66ae1d60SPeiming Liu   // Our single loop.
329*66ae1d60SPeiming Liu   const unsigned l0 = 0;
330*66ae1d60SPeiming Liu 
MergerTest4T1L()331*66ae1d60SPeiming Liu   MergerTest4T1L() : MergerTestBase(4, 1) {
332*66ae1d60SPeiming Liu     // Tensor 0: sparse input vector.
333*66ae1d60SPeiming Liu     merger.addExp(Kind::kTensor, t0, -1u);
334*66ae1d60SPeiming Liu     merger.setDim(t0, l0, Dim::kSparse);
335*66ae1d60SPeiming Liu 
336*66ae1d60SPeiming Liu     // Tensor 1: sparse input vector.
337*66ae1d60SPeiming Liu     merger.addExp(Kind::kTensor, t1, -1u);
338*66ae1d60SPeiming Liu     merger.setDim(t1, l0, Dim::kSparse);
339*66ae1d60SPeiming Liu 
340*66ae1d60SPeiming Liu     // Tensor 2: sparse input vector
341*66ae1d60SPeiming Liu     merger.addExp(Kind::kTensor, t2, -1u);
342*66ae1d60SPeiming Liu     merger.setDim(t2, l0, Dim::kSparse);
343*66ae1d60SPeiming Liu 
344*66ae1d60SPeiming Liu     // Tensor 3: dense output vector
345*66ae1d60SPeiming Liu     merger.addExp(Kind::kTensor, t3, -1u);
346*66ae1d60SPeiming Liu     merger.setDim(t3, l0, Dim::kDense);
347*66ae1d60SPeiming Liu   }
348*66ae1d60SPeiming Liu };
349*66ae1d60SPeiming Liu 
350*66ae1d60SPeiming Liu ///
351*66ae1d60SPeiming Liu /// Tests with both sparse and dense input.
352*66ae1d60SPeiming Liu ///
353*66ae1d60SPeiming Liu 
354*66ae1d60SPeiming Liu class MergerTest3T1LD : public MergerTestBase {
355*66ae1d60SPeiming Liu protected:
356*66ae1d60SPeiming Liu   // Our three tensors (two inputs, one output).
357*66ae1d60SPeiming Liu   const unsigned t0 = 0, t1 = 1, t2 = 2;
358*66ae1d60SPeiming Liu 
359*66ae1d60SPeiming Liu   // Our single loop.
360*66ae1d60SPeiming Liu   const unsigned l0 = 0;
361*66ae1d60SPeiming Liu 
MergerTest3T1LD()362*66ae1d60SPeiming Liu   MergerTest3T1LD() : MergerTestBase(3, 1) {
363*66ae1d60SPeiming Liu     // Tensor 0: sparse input vector.
364*66ae1d60SPeiming Liu     merger.addExp(Kind::kTensor, t0, -1u);
365*66ae1d60SPeiming Liu     merger.setDim(t0, l0, Dim::kSparse);
366*66ae1d60SPeiming Liu 
367*66ae1d60SPeiming Liu     // Tensor 1: dense input vector.
368*66ae1d60SPeiming Liu     merger.addExp(Kind::kTensor, t1, -1u);
369*66ae1d60SPeiming Liu     merger.setDim(t1, l0, Dim::kDense);
370*66ae1d60SPeiming Liu 
371*66ae1d60SPeiming Liu     // Tensor 2: dense output vector.
372*66ae1d60SPeiming Liu     merger.addExp(Kind::kTensor, t2, -1u);
373*66ae1d60SPeiming Liu     merger.setDim(t2, l0, Dim::kDense);
374*66ae1d60SPeiming Liu   }
375*66ae1d60SPeiming Liu };
376*66ae1d60SPeiming Liu 
377be0a7e9fSMehdi Amini } // namespace
37840843347SGus Smith 
379*66ae1d60SPeiming Liu /// Vector addition (disjunction) of 2 vectors. i.e.;
38040843347SGus Smith ///   a(i) = b(i) + c(i)
38140843347SGus Smith /// which should form the 3 lattice points
38240843347SGus Smith /// {
38340843347SGus Smith ///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
38440843347SGus Smith ///   lat( i_00 / tensor_0 )
38540843347SGus Smith ///   lat( i_01 / tensor_1 )
38640843347SGus Smith /// }
387*66ae1d60SPeiming Liu /// and after optimization, the lattice points do not change (as there is no
388*66ae1d60SPeiming Liu /// duplicated point and all input vectors are sparse vector).
38940843347SGus Smith /// {
39040843347SGus Smith ///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
39140843347SGus Smith ///   lat( i_00 / tensor_0 )
392*66ae1d60SPeiming Liu ///   lat( i_01 / tensor_1 )
39340843347SGus Smith /// }
394*66ae1d60SPeiming Liu #define IMPL_MERGER_TEST_DISJ(OP)                                              \
395*66ae1d60SPeiming Liu   TEST_F(MergerTest3T1L, vector_##OP) {                                        \
396*66ae1d60SPeiming Liu     auto e = OP##Expr(tensor(t0), tensor(t1));                                 \
397*66ae1d60SPeiming Liu     auto p0 = tensorPattern(t0);                                               \
398*66ae1d60SPeiming Liu     auto p1 = tensorPattern(t1);                                               \
399*66ae1d60SPeiming Liu     auto s = merger.buildLattices(e, l0);                                      \
400*66ae1d60SPeiming Liu                                                                                \
401*66ae1d60SPeiming Liu     expectNumLatPoints(s, 3);                                                  \
402*66ae1d60SPeiming Liu     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
403*66ae1d60SPeiming Liu                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
404*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}));      \
405*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}));      \
406*66ae1d60SPeiming Liu                                                                                \
407*66ae1d60SPeiming Liu     s = merger.optimizeSet(s);                                                 \
408*66ae1d60SPeiming Liu     expectNumLatPoints(s, 3);                                                  \
409*66ae1d60SPeiming Liu     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
410*66ae1d60SPeiming Liu                    loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
411*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}),       \
412*66ae1d60SPeiming Liu                               true);                                           \
413*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}),       \
414*66ae1d60SPeiming Liu                               true);                                           \
41540843347SGus Smith   }
41640843347SGus Smith 
417*66ae1d60SPeiming Liu FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
418*66ae1d60SPeiming Liu 
419*66ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_DISJ
420*66ae1d60SPeiming Liu 
421*66ae1d60SPeiming Liu /// Vector multiplication (conjunction) of 2 vectors, i.e.;
42240843347SGus Smith ///   a(i) = b(i) * c(i)
42340843347SGus Smith /// which should form the single lattice point
42440843347SGus Smith /// {
42540843347SGus Smith ///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
42640843347SGus Smith /// }
427*66ae1d60SPeiming Liu #define IMPL_MERGER_TEST_CONJ(OP)                                              \
428*66ae1d60SPeiming Liu   TEST_F(MergerTest3T1L, vector_##OP) {                                        \
429*66ae1d60SPeiming Liu     auto e = OP##Expr(t0, t1);                                                 \
430*66ae1d60SPeiming Liu     auto p0 = tensorPattern(t0);                                               \
431*66ae1d60SPeiming Liu     auto p1 = tensorPattern(t1);                                               \
432*66ae1d60SPeiming Liu     auto s = merger.buildLattices(e, l0);                                      \
433*66ae1d60SPeiming Liu                                                                                \
434*66ae1d60SPeiming Liu     expectNumLatPoints(s, 1);                                                  \
435*66ae1d60SPeiming Liu     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
436*66ae1d60SPeiming Liu                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
437*66ae1d60SPeiming Liu                                                                                \
438*66ae1d60SPeiming Liu     s = merger.optimizeSet(s);                                                 \
439*66ae1d60SPeiming Liu     expectNumLatPoints(s, 1);                                                  \
440*66ae1d60SPeiming Liu     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
441*66ae1d60SPeiming Liu                    loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
44240843347SGus Smith   }
443*66ae1d60SPeiming Liu 
444*66ae1d60SPeiming Liu FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
445*66ae1d60SPeiming Liu 
446*66ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_CONJ
447*66ae1d60SPeiming Liu 
448*66ae1d60SPeiming Liu /// Vector multiplication (conjunction) then addition (disjunction), i.e.;
449*66ae1d60SPeiming Liu ///   a(i) = b(i) * c(i) + d(i);
450*66ae1d60SPeiming Liu /// which should form
451*66ae1d60SPeiming Liu /// {
452*66ae1d60SPeiming Liu ///    lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 )
453*66ae1d60SPeiming Liu ///    lat( i_00 i_01 / tensor_0 * tensor_1
454*66ae1d60SPeiming Liu ///    lat( i_02 / tensor_2 )
455*66ae1d60SPeiming Liu /// }
456*66ae1d60SPeiming Liu #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ)                                 \
457*66ae1d60SPeiming Liu   TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) {                             \
458*66ae1d60SPeiming Liu     auto em = CONJ##Expr(t0, t1);                                              \
459*66ae1d60SPeiming Liu     auto e = DISJ##Expr(em, t2);                                               \
460*66ae1d60SPeiming Liu     auto p0 = tensorPattern(t0);                                               \
461*66ae1d60SPeiming Liu     auto p1 = tensorPattern(t1);                                               \
462*66ae1d60SPeiming Liu     auto p2 = tensorPattern(t2);                                               \
463*66ae1d60SPeiming Liu     auto s = merger.buildLattices(e, l0);                                      \
464*66ae1d60SPeiming Liu                                                                                \
465*66ae1d60SPeiming Liu     expectNumLatPoints(s, 3);                                                  \
466*66ae1d60SPeiming Liu     expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2),        \
467*66ae1d60SPeiming Liu                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
468*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1),             \
469*66ae1d60SPeiming Liu                               loopsToBits({{l0, t0}, {l0, t1}}));              \
470*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}}));      \
471*66ae1d60SPeiming Liu                                                                                \
472*66ae1d60SPeiming Liu     s = merger.optimizeSet(s);                                                 \
473*66ae1d60SPeiming Liu     expectNumLatPoints(s, 3);                                                  \
474*66ae1d60SPeiming Liu     expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2),        \
475*66ae1d60SPeiming Liu                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
476*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1),             \
477*66ae1d60SPeiming Liu                               loopsToBits({{l0, t0}, {l0, t1}}));              \
478*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}}));      \
479*66ae1d60SPeiming Liu   }
480*66ae1d60SPeiming Liu 
481*66ae1d60SPeiming Liu FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
482*66ae1d60SPeiming Liu 
483*66ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_CONJ_DISJ
484*66ae1d60SPeiming Liu 
485*66ae1d60SPeiming Liu /// Vector addition (disjunction) then addition (disjunction), i.e.:
486*66ae1d60SPeiming Liu ///   a(i) = b(i) + c(i) + d(i)
487*66ae1d60SPeiming Liu /// which should form the 7 lattice points below; optimizeSet keeps all 7:
488*66ae1d60SPeiming Liu /// {
489*66ae1d60SPeiming Liu ///   lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 )
490*66ae1d60SPeiming Liu ///   lat( i_02 i_01 / tensor_2 + tensor_1 )
491*66ae1d60SPeiming Liu ///   lat( i_02 i_00 / tensor_2 + tensor_0 )
492*66ae1d60SPeiming Liu ///   lat( i_01 i_00 / tensor_1 + tensor_0 )
493*66ae1d60SPeiming Liu ///   lat( i_02 / tensor_2 )
494*66ae1d60SPeiming Liu ///   lat( i_01 / tensor_1 )
495*66ae1d60SPeiming Liu ///   lat( i_00 / tensor_0 )
496*66ae1d60SPeiming Liu /// }
497*66ae1d60SPeiming Liu #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2)                               \
498*66ae1d60SPeiming Liu   TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) {                           \
499*66ae1d60SPeiming Liu     auto em = DISJ1##Expr(t0, t1);                                             \
500*66ae1d60SPeiming Liu     auto e = DISJ2##Expr(em, t2);                                              \
501*66ae1d60SPeiming Liu     auto p0 = tensorPattern(t0);                                               \
502*66ae1d60SPeiming Liu     auto p1 = tensorPattern(t1);                                               \
503*66ae1d60SPeiming Liu     auto p2 = tensorPattern(t2);                                               \
504*66ae1d60SPeiming Liu     auto s = merger.buildLattices(e, l0);                                      \
505*66ae1d60SPeiming Liu                                                                                \
506*66ae1d60SPeiming Liu     expectNumLatPoints(s, 7);                                                  \
507*66ae1d60SPeiming Liu     expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2),      \
508*66ae1d60SPeiming Liu                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
509*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2),            \
510*66ae1d60SPeiming Liu                               loopsToBits({{l0, t1}, {l0, t2}}));              \
511*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2),            \
512*66ae1d60SPeiming Liu                               loopsToBits({{l0, t0}, {l0, t2}}));              \
513*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1),            \
514*66ae1d60SPeiming Liu                               loopsToBits({{l0, t0}, {l0, t1}}));              \
515*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}}));      \
516*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}}));      \
517*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}}));      \
518*66ae1d60SPeiming Liu                                                                                \
519*66ae1d60SPeiming Liu     s = merger.optimizeSet(s);                                                 \
520*66ae1d60SPeiming Liu     expectNumLatPoints(s, 7);                                                  \
521*66ae1d60SPeiming Liu     expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2),      \
522*66ae1d60SPeiming Liu                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
523*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2),            \
524*66ae1d60SPeiming Liu                               loopsToBits({{l0, t1}, {l0, t2}}));              \
525*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2),            \
526*66ae1d60SPeiming Liu                               loopsToBits({{l0, t0}, {l0, t2}}));              \
527*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1),            \
528*66ae1d60SPeiming Liu                               loopsToBits({{l0, t0}, {l0, t1}}));              \
529*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}}));      \
530*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}}));      \
531*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}}));      \
532*66ae1d60SPeiming Liu   }
533*66ae1d60SPeiming Liu 
534*66ae1d60SPeiming Liu FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
535*66ae1d60SPeiming Liu 
536*66ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_DISJ_DISJ
537*66ae1d60SPeiming Liu 
538*66ae1d60SPeiming Liu /// Vector multiplication (conjunction) then multiplication (conjunction), i.e.:
539*66ae1d60SPeiming Liu ///   a(i) = b(i) * c(i) * d(i)
540*66ae1d60SPeiming Liu /// which should form a single lattice point, before and after optimizeSet:
541*66ae1d60SPeiming Liu /// {
542*66ae1d60SPeiming Liu ///   lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
543*66ae1d60SPeiming Liu /// }
544*66ae1d60SPeiming Liu #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2)                               \
545*66ae1d60SPeiming Liu   TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) {                           \
546*66ae1d60SPeiming Liu     auto em = CONJ1##Expr(t0, t1);                                             \
547*66ae1d60SPeiming Liu     auto e = CONJ2##Expr(em, t2);                                              \
548*66ae1d60SPeiming Liu     auto p0 = tensorPattern(t0);                                               \
549*66ae1d60SPeiming Liu     auto p1 = tensorPattern(t1);                                               \
550*66ae1d60SPeiming Liu     auto p2 = tensorPattern(t2);                                               \
551*66ae1d60SPeiming Liu     auto s = merger.buildLattices(e, l0);                                      \
552*66ae1d60SPeiming Liu     expectNumLatPoints(s, 1);                                                  \
553*66ae1d60SPeiming Liu     expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
554*66ae1d60SPeiming Liu                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
555*66ae1d60SPeiming Liu     s = merger.optimizeSet(s);                                                 \
556*66ae1d60SPeiming Liu     expectNumLatPoints(s, 1);                                                  \
557*66ae1d60SPeiming Liu     expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
558*66ae1d60SPeiming Liu                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true);         \
559*66ae1d60SPeiming Liu   }
560*66ae1d60SPeiming Liu 
561*66ae1d60SPeiming Liu FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
562*66ae1d60SPeiming Liu 
563*66ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_CONJ_CONJ
564*66ae1d60SPeiming Liu 
565*66ae1d60SPeiming Liu /// Vector addition (disjunction) of 2 vectors, i.e.;
566*66ae1d60SPeiming Liu ///   a(i) = b(i) + c(i)
567*66ae1d60SPeiming Liu /// which should form the 3 lattice points
568*66ae1d60SPeiming Liu /// {
569*66ae1d60SPeiming Liu ///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) )
570*66ae1d60SPeiming Liu ///   lat( i_00 / sparse_tensor_0 )
571*66ae1d60SPeiming Liu ///   lat( i_01 / dense_tensor_1 )
572*66ae1d60SPeiming Liu /// }
573*66ae1d60SPeiming Liu /// which should be optimized to
574*66ae1d60SPeiming Liu /// {
575*66ae1d60SPeiming Liu ///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton)
576*66ae1d60SPeiming Liu ///   lat( i_01 / dense_tensor_0 ) (no sparse dimension)
577*66ae1d60SPeiming Liu /// }
578*66ae1d60SPeiming Liu ///
579*66ae1d60SPeiming Liu /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
580*66ae1d60SPeiming Liu /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
581*66ae1d60SPeiming Liu #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP)                                    \
582*66ae1d60SPeiming Liu   TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
583*66ae1d60SPeiming Liu     auto e = OP##Expr(tensor(t0), tensor(t1));                                 \
584*66ae1d60SPeiming Liu     auto p0 = tensorPattern(t0);                                               \
585*66ae1d60SPeiming Liu     auto p1 = tensorPattern(t1);                                               \
586*66ae1d60SPeiming Liu     auto s = merger.buildLattices(e, l0);                                      \
587*66ae1d60SPeiming Liu                                                                                \
588*66ae1d60SPeiming Liu     expectNumLatPoints(s, 3);                                                  \
589*66ae1d60SPeiming Liu     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
590*66ae1d60SPeiming Liu                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
591*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}));      \
592*66ae1d60SPeiming Liu     expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}));      \
593*66ae1d60SPeiming Liu                                                                                \
594*66ae1d60SPeiming Liu     s = merger.optimizeSet(s);                                                 \
595*66ae1d60SPeiming Liu     expectNumLatPoints(s, 2);                                                  \
596*66ae1d60SPeiming Liu     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
597*66ae1d60SPeiming Liu                    loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
598*66ae1d60SPeiming Liu     expectLatPoint(s, lat(1), p1, loopsToBits({{l0, t1}}), true);              \
599*66ae1d60SPeiming Liu   }
600*66ae1d60SPeiming Liu 
601*66ae1d60SPeiming Liu FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
602*66ae1d60SPeiming Liu 
603*66ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
604*66ae1d60SPeiming Liu 
605*66ae1d60SPeiming Liu /// Vector multiplication (conjunction) of 2 vectors, i.e.:
606*66ae1d60SPeiming Liu ///   a(i) = b(i) * c(i)
607*66ae1d60SPeiming Liu /// which should form the single lattice point
608*66ae1d60SPeiming Liu /// {
609*66ae1d60SPeiming Liu ///   lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) )
610*66ae1d60SPeiming Liu /// }
611*66ae1d60SPeiming Liu /// which should be optimized to
612*66ae1d60SPeiming Liu /// {
613*66ae1d60SPeiming Liu ///   lat( i_00 / (sparse_tensor_0 * dense_tensor_1) )
614*66ae1d60SPeiming Liu /// }
615*66ae1d60SPeiming Liu /// since i_01 is a dense dimension.
616*66ae1d60SPeiming Liu #define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP)                                    \
617*66ae1d60SPeiming Liu   TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
618*66ae1d60SPeiming Liu     auto e = OP##Expr(t0, t1);                                                 \
619*66ae1d60SPeiming Liu     auto p0 = tensorPattern(t0);                                               \
620*66ae1d60SPeiming Liu     auto p1 = tensorPattern(t1);                                               \
621*66ae1d60SPeiming Liu     auto s = merger.buildLattices(e, l0);                                      \
622*66ae1d60SPeiming Liu                                                                                \
623*66ae1d60SPeiming Liu     expectNumLatPoints(s, 1);                                                  \
624*66ae1d60SPeiming Liu     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
625*66ae1d60SPeiming Liu                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
626*66ae1d60SPeiming Liu                                                                                \
627*66ae1d60SPeiming Liu     s = merger.optimizeSet(s);                                                 \
628*66ae1d60SPeiming Liu     expectNumLatPoints(s, 1);                                                  \
629*66ae1d60SPeiming Liu     expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t0}}),    \
630*66ae1d60SPeiming Liu                    true);                                                      \
631*66ae1d60SPeiming Liu   }
632*66ae1d60SPeiming Liu 
633*66ae1d60SPeiming Liu FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
634*66ae1d60SPeiming Liu 
635*66ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
636*66ae1d60SPeiming Liu 
637*66ae1d60SPeiming Liu // TODO: multi-dim tests
638*66ae1d60SPeiming Liu 
639*66ae1d60SPeiming Liu // restore warning status
640*66ae1d60SPeiming Liu #if defined(_MSC_VER) && !defined(__clang__)
641*66ae1d60SPeiming Liu #pragma warning(pop)
642*66ae1d60SPeiming Liu #endif
643