1 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
2 #include "llvm/Support/Compiler.h"
3 #include "gmock/gmock.h"
4 #include "gtest/gtest.h"
5 #include <memory>
6 
7 using namespace mlir;
8 using namespace mlir::sparse_tensor;
9 
// Silence 'warning C4002: too many arguments for function-like macro
//                          invocation'
// since MSVC handles ##__VA_ARGS__ differently than gcc/clang.
13 
14 #if defined(_MSC_VER) && !defined(__clang__)
15 #pragma warning(push)
16 #pragma warning(disable : 4002)
17 #endif
18 
19 namespace {
20 
///
/// Macros that iterate over binary operations and over combinations (pairs)
/// of binary operations.
///

/// Applies FN(name, kind) to every binary operation that gets a pattern
/// builder and an expression builder in this file.
#define FOREVERY_BINOP(FN)                                                     \
  FN(mulf, Kind::kMulF)                                                        \
  FN(mulc, Kind::kMulC)                                                        \
  FN(muli, Kind::kMulI)                                                        \
  FN(addf, Kind::kAddF)                                                        \
  FN(addc, Kind::kAddC)                                                        \
  FN(addi, Kind::kAddI)                                                        \
  FN(subf, Kind::kSubF)                                                        \
  FN(subc, Kind::kSubC)                                                        \
  FN(subi, Kind::kSubI)                                                        \
  FN(andi, Kind::kAndI)                                                        \
  FN(xori, Kind::kXorI)                                                        \
  FN(ori, Kind::kOrI)

// TODO: Disjunctive binary operations that need special handling are not
// included; e.g., division is not tested (for now) as it needs a constant
// non-zero divisor.
// ##__VA_ARGS__ (a GNU extension) drops the preceding comma when __VA_ARGS__
// is empty, so FN may be invoked with one or two arguments.
#define FOREVERY_COMMON_DISJ_BINOP(FN, ...)                                    \
  FN(addf, ##__VA_ARGS__)                                                      \
  FN(addc, ##__VA_ARGS__)                                                      \
  FN(addi, ##__VA_ARGS__)                                                      \
  FN(xori, ##__VA_ARGS__)                                                      \
  FN(ori, ##__VA_ARGS__)

// TODO: Conjunctive binary operations that need special handling are not
// included; e.g., subtraction yields a different pattern as it is mapped to
// a negate operation.
#define FOREVERY_COMMON_CONJ_BINOP(FN, ...)                                    \
  FN(mulf, ##__VA_ARGS__)                                                      \
  FN(mulc, ##__VA_ARGS__)                                                      \
  FN(muli, ##__VA_ARGS__)                                                      \
  FN(andi, ##__VA_ARGS__)

/// Applies FN(conj, disj) for every common conjunctive `conj` and every
/// common disjunctive `disj`.
#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(FN)                            \
  FOREVERY_COMMON_CONJ_BINOP(FN, addf)                                         \
  FOREVERY_COMMON_CONJ_BINOP(FN, addc)                                         \
  FOREVERY_COMMON_CONJ_BINOP(FN, addi)                                         \
  FOREVERY_COMMON_CONJ_BINOP(FN, xori)                                         \
  FOREVERY_COMMON_CONJ_BINOP(FN, ori)

/// Applies FN(conj1, conj2) for every ordered pair of common conjunctives.
#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(FN)                            \
  FOREVERY_COMMON_CONJ_BINOP(FN, mulf)                                         \
  FOREVERY_COMMON_CONJ_BINOP(FN, mulc)                                         \
  FOREVERY_COMMON_CONJ_BINOP(FN, muli)                                         \
  FOREVERY_COMMON_CONJ_BINOP(FN, andi)

/// Applies FN(disj1, disj2) for every ordered pair of common disjunctives.
#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(FN)                            \
  FOREVERY_COMMON_DISJ_BINOP(FN, addf)                                         \
  FOREVERY_COMMON_DISJ_BINOP(FN, addc)                                         \
  FOREVERY_COMMON_DISJ_BINOP(FN, addi)                                         \
  FOREVERY_COMMON_DISJ_BINOP(FN, ori)                                          \
  FOREVERY_COMMON_DISJ_BINOP(FN, xori)
78 
79 ///
80 /// Helper classes/functions for testing Merger.
81 ///
82 
83 /// Simple recursive data structure used to match expressions in Mergers.
84 struct Pattern {
85   Kind kind;
86 
87   /// Expressions representing tensors simply have a tensor number.
88   unsigned tensorNum;
89 
90   /// Tensor operations point to their children.
91   std::shared_ptr<Pattern> e0;
92   std::shared_ptr<Pattern> e1;
93 
94   /// Constructors.
95   /// Rather than using these, please use the readable helper constructor
96   /// functions below to make tests more readable.
Pattern__anonb1ca32140111::Pattern97   Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
Pattern__anonb1ca32140111::Pattern98   Pattern(Kind kind, const std::shared_ptr<Pattern> &e0,
99           const std::shared_ptr<Pattern> &e1)
100       : kind(kind), e0(e0), e1(e1) {
101     assert(kind >= Kind::kMulF);
102     assert(e0 && e1);
103   }
104 };
105 
106 ///
107 /// Readable Pattern builder functions.
108 /// These should be preferred over the actual constructors.
109 ///
110 
tensorPattern(unsigned tensorNum)111 static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
112   return std::make_shared<Pattern>(tensorNum);
113 }
114 
115 #define IMPL_BINOP_PATTERN(OP, KIND)                                           \
116   LLVM_ATTRIBUTE_UNUSED static std::shared_ptr<Pattern> OP##Pattern(           \
117       const std::shared_ptr<Pattern> &e0,                                      \
118       const std::shared_ptr<Pattern> &e1) {                                    \
119     return std::make_shared<Pattern>(KIND, e0, e1);                            \
120   }
121 
122 FOREVERY_BINOP(IMPL_BINOP_PATTERN)
123 
124 #undef IMPL_BINOP_PATTERN
125 
126 class MergerTestBase : public ::testing::Test {
127 protected:
MergerTestBase(unsigned numTensors,unsigned numLoops)128   MergerTestBase(unsigned numTensors, unsigned numLoops)
129       : numTensors(numTensors), numLoops(numLoops),
130         merger(numTensors, numLoops) {}
131 
132   ///
133   /// Expression construction helpers.
134   ///
135 
tensor(unsigned tensor)136   unsigned tensor(unsigned tensor) {
137     return merger.addExp(Kind::kTensor, tensor);
138   }
139 
140 #define IMPL_BINOP_EXPR(OP, KIND)                                              \
141   LLVM_ATTRIBUTE_UNUSED unsigned OP##Expr(unsigned e0, unsigned e1) {          \
142     return merger.addExp(KIND, e0, e1);                                        \
143   }
144 
FOREVERY_BINOP(IMPL_BINOP_EXPR)145   FOREVERY_BINOP(IMPL_BINOP_EXPR)
146 
147 #undef IMPL_BINOP_EXPR
148 
149   ///
150   /// Comparison helpers.
151   ///
152 
153   /// For readability of tests.
154   unsigned lat(unsigned lat) { return lat; }
155 
156   /// Returns true if a lattice point with an expression matching the given
157   /// pattern and bits matching the given bits is present in lattice points
158   /// [p, p+n) of lattice set s. This is useful for testing partial ordering
159   /// constraints between lattice points. We generally know how contiguous
160   /// groups of lattice points should be ordered with respect to other groups,
161   /// but there is no required ordering within groups.
162   /// If simple is true, then compare the lat.simple field instead to test the
163   /// result after optimization
latPointWithinRange(unsigned s,unsigned p,unsigned n,const std::shared_ptr<Pattern> & pattern,const BitVector & bits,bool simple)164   bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
165                            const std::shared_ptr<Pattern> &pattern,
166                            const BitVector &bits, bool simple) {
167     for (unsigned i = p; i < p + n; ++i) {
168       if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
169           compareBits(s, i, bits, simple))
170         return true;
171     }
172     return false;
173   }
174 
175   /// Wrapper over latPointWithinRange for readability of tests.
expectLatPointWithinRange(unsigned s,unsigned p,unsigned n,const std::shared_ptr<Pattern> & pattern,const BitVector & bits,bool simple=false)176   void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
177                                  const std::shared_ptr<Pattern> &pattern,
178                                  const BitVector &bits, bool simple = false) {
179     EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits, simple));
180   }
181 
182   /// Wrapper over expectLatPointWithinRange for a single lat point.
expectLatPoint(unsigned s,unsigned p,const std::shared_ptr<Pattern> & pattern,const BitVector & bits,bool simple=false)183   void expectLatPoint(unsigned s, unsigned p,
184                       const std::shared_ptr<Pattern> &pattern,
185                       const BitVector &bits, bool simple = false) {
186     EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits, simple));
187   }
188 
189   /// Converts a vector of (loop, tensor) pairs to a bitvector with the
190   /// corresponding bits set.
191   BitVector
loopsToBits(const std::vector<std::pair<unsigned,unsigned>> & loops)192   loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) {
193     BitVector testBits = BitVector(numTensors + 1, false);
194     for (auto l : loops) {
195       auto loop = std::get<0>(l);
196       auto tensor = std::get<1>(l);
197       testBits.set(numTensors * loop + tensor);
198     }
199     return testBits;
200   }
201 
202   /// Returns true if the bits of lattice point p in set s match the given bits.
203   /// If simple is true, then compare the lat.simple field instead to test the
204   /// result after optimization
compareBits(unsigned s,unsigned p,const BitVector & bits,bool simple)205   bool compareBits(unsigned s, unsigned p, const BitVector &bits, bool simple) {
206     if (simple)
207       return merger.lat(merger.set(s)[p]).simple == bits;
208     return merger.lat(merger.set(s)[p]).bits == bits;
209   }
210 
211   /// Check that there are n lattice points in set s.
expectNumLatPoints(unsigned s,unsigned n)212   void expectNumLatPoints(unsigned s, unsigned n) {
213     EXPECT_THAT(merger.set(s).size(), n);
214   }
215 
216   /// Compares expressions for equality. Equality is defined recursively as:
217   /// - Operations are equal if they have the same kind and children.
218   /// - Leaf tensors are equal if they refer to the same tensor.
compareExpression(unsigned e,const std::shared_ptr<Pattern> & pattern)219   bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
220     auto tensorExp = merger.exp(e);
221     if (tensorExp.kind != pattern->kind)
222       return false;
223     switch (tensorExp.kind) {
224     // Leaf.
225     case kTensor:
226       return tensorExp.tensor == pattern->tensorNum;
227     case kInvariant:
228     case kIndex:
229       llvm_unreachable("invariant not handled yet");
230     // Unary operations.
231     case kAbsF:
232     case kAbsC:
233     case kCeilF:
234     case kFloorF:
235     case kSqrtF:
236     case kSqrtC:
237     case kExpm1F:
238     case kExpm1C:
239     case kLog1pF:
240     case kLog1pC:
241     case kSinF:
242     case kSinC:
243     case kTanhF:
244     case kTanhC:
245     case kNegF:
246     case kNegC:
247     case kNegI:
248     case kTruncF:
249     case kExtF:
250     case kCastFS:
251     case kCastFU:
252     case kCastSF:
253     case kCastUF:
254     case kCastS:
255     case kCastU:
256     case kCastIdx:
257     case kTruncI:
258     case kCIm:
259     case kCRe:
260     case kBitCast:
261     case kBinaryBranch:
262     case kUnary:
263     case kShlI:
264     case kBinary:
265       return compareExpression(tensorExp.children.e0, pattern->e0);
266     // Binary operations.
267     case kMulF:
268     case kMulC:
269     case kMulI:
270     case kDivF:
271     case kDivC:
272     case kDivS:
273     case kDivU:
274     case kAddF:
275     case kAddC:
276     case kAddI:
277     case kSubF:
278     case kSubC:
279     case kSubI:
280     case kAndI:
281     case kOrI:
282     case kXorI:
283     case kShrS:
284     case kShrU:
285       return compareExpression(tensorExp.children.e0, pattern->e0) &&
286              compareExpression(tensorExp.children.e1, pattern->e1);
287     }
288     llvm_unreachable("unexpected kind");
289   }
290 
291   unsigned numTensors;
292   unsigned numLoops;
293   Merger merger;
294 };
295 
296 ///
297 /// Tests with all sparse inputs.
298 ///
299 
300 class MergerTest3T1L : public MergerTestBase {
301 protected:
302   // Our three tensors (two inputs, one output).
303   const unsigned t0 = 0, t1 = 1, t2 = 2;
304 
305   // Our single loop.
306   const unsigned l0 = 0;
307 
MergerTest3T1L()308   MergerTest3T1L() : MergerTestBase(3, 1) {
309     // Tensor 0: sparse input vector.
310     merger.addExp(Kind::kTensor, t0, -1u);
311     merger.setDim(t0, l0, Dim::kSparse);
312 
313     // Tensor 1: sparse input vector.
314     merger.addExp(Kind::kTensor, t1, -1u);
315     merger.setDim(t1, l0, Dim::kSparse);
316 
317     // Tensor 2: dense output vector.
318     merger.addExp(Kind::kTensor, t2, -1u);
319     merger.setDim(t2, l0, Dim::kDense);
320   }
321 };
322 
323 class MergerTest4T1L : public MergerTestBase {
324 protected:
325   // Our four tensors (three inputs, one output).
326   const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
327 
328   // Our single loop.
329   const unsigned l0 = 0;
330 
MergerTest4T1L()331   MergerTest4T1L() : MergerTestBase(4, 1) {
332     // Tensor 0: sparse input vector.
333     merger.addExp(Kind::kTensor, t0, -1u);
334     merger.setDim(t0, l0, Dim::kSparse);
335 
336     // Tensor 1: sparse input vector.
337     merger.addExp(Kind::kTensor, t1, -1u);
338     merger.setDim(t1, l0, Dim::kSparse);
339 
340     // Tensor 2: sparse input vector
341     merger.addExp(Kind::kTensor, t2, -1u);
342     merger.setDim(t2, l0, Dim::kSparse);
343 
344     // Tensor 3: dense output vector
345     merger.addExp(Kind::kTensor, t3, -1u);
346     merger.setDim(t3, l0, Dim::kDense);
347   }
348 };
349 
350 ///
351 /// Tests with both sparse and dense input.
352 ///
353 
354 class MergerTest3T1LD : public MergerTestBase {
355 protected:
356   // Our three tensors (two inputs, one output).
357   const unsigned t0 = 0, t1 = 1, t2 = 2;
358 
359   // Our single loop.
360   const unsigned l0 = 0;
361 
MergerTest3T1LD()362   MergerTest3T1LD() : MergerTestBase(3, 1) {
363     // Tensor 0: sparse input vector.
364     merger.addExp(Kind::kTensor, t0, -1u);
365     merger.setDim(t0, l0, Dim::kSparse);
366 
367     // Tensor 1: dense input vector.
368     merger.addExp(Kind::kTensor, t1, -1u);
369     merger.setDim(t1, l0, Dim::kDense);
370 
371     // Tensor 2: dense output vector.
372     merger.addExp(Kind::kTensor, t2, -1u);
373     merger.setDim(t2, l0, Dim::kDense);
374   }
375 };
376 
377 } // namespace
378 
379 /// Vector addition (disjunction) of 2 vectors. i.e.;
380 ///   a(i) = b(i) + c(i)
381 /// which should form the 3 lattice points
382 /// {
383 ///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
384 ///   lat( i_00 / tensor_0 )
385 ///   lat( i_01 / tensor_1 )
386 /// }
387 /// and after optimization, the lattice points do not change (as there is no
388 /// duplicated point and all input vectors are sparse vector).
389 /// {
390 ///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
391 ///   lat( i_00 / tensor_0 )
392 ///   lat( i_01 / tensor_1 )
393 /// }
394 #define IMPL_MERGER_TEST_DISJ(OP)                                              \
395   TEST_F(MergerTest3T1L, vector_##OP) {                                        \
396     auto e = OP##Expr(tensor(t0), tensor(t1));                                 \
397     auto p0 = tensorPattern(t0);                                               \
398     auto p1 = tensorPattern(t1);                                               \
399     auto s = merger.buildLattices(e, l0);                                      \
400                                                                                \
401     expectNumLatPoints(s, 3);                                                  \
402     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
403                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
404     expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}));      \
405     expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}));      \
406                                                                                \
407     s = merger.optimizeSet(s);                                                 \
408     expectNumLatPoints(s, 3);                                                  \
409     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
410                    loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
411     expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}),       \
412                               true);                                           \
413     expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}),       \
414                               true);                                           \
415   }
416 
417 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
418 
419 #undef IMPL_MERGER_TEST_DISJ
420 
421 /// Vector multiplication (conjunction) of 2 vectors, i.e.;
422 ///   a(i) = b(i) * c(i)
423 /// which should form the single lattice point
424 /// {
425 ///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
426 /// }
427 #define IMPL_MERGER_TEST_CONJ(OP)                                              \
428   TEST_F(MergerTest3T1L, vector_##OP) {                                        \
429     auto e = OP##Expr(t0, t1);                                                 \
430     auto p0 = tensorPattern(t0);                                               \
431     auto p1 = tensorPattern(t1);                                               \
432     auto s = merger.buildLattices(e, l0);                                      \
433                                                                                \
434     expectNumLatPoints(s, 1);                                                  \
435     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
436                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
437                                                                                \
438     s = merger.optimizeSet(s);                                                 \
439     expectNumLatPoints(s, 1);                                                  \
440     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
441                    loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
442   }
443 
444 FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
445 
446 #undef IMPL_MERGER_TEST_CONJ
447 
448 /// Vector multiplication (conjunction) then addition (disjunction), i.e.;
449 ///   a(i) = b(i) * c(i) + d(i);
450 /// which should form
451 /// {
452 ///    lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 )
453 ///    lat( i_00 i_01 / tensor_0 * tensor_1
454 ///    lat( i_02 / tensor_2 )
455 /// }
456 #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ)                                 \
457   TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) {                             \
458     auto em = CONJ##Expr(t0, t1);                                              \
459     auto e = DISJ##Expr(em, t2);                                               \
460     auto p0 = tensorPattern(t0);                                               \
461     auto p1 = tensorPattern(t1);                                               \
462     auto p2 = tensorPattern(t2);                                               \
463     auto s = merger.buildLattices(e, l0);                                      \
464                                                                                \
465     expectNumLatPoints(s, 3);                                                  \
466     expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2),        \
467                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
468     expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1),             \
469                               loopsToBits({{l0, t0}, {l0, t1}}));              \
470     expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}}));      \
471                                                                                \
472     s = merger.optimizeSet(s);                                                 \
473     expectNumLatPoints(s, 3);                                                  \
474     expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2),        \
475                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
476     expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1),             \
477                               loopsToBits({{l0, t0}, {l0, t1}}));              \
478     expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}}));      \
479   }
480 
481 FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
482 
483 #undef IMPL_MERGER_TEST_CONJ_DISJ
484 
485 /// Vector addition (disjunction) then addition (disjunction), i.e.;
486 ///   a(i) = b(i) + c(i) + d(i)
487 /// which should form
488 /// {
489 ///   lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 )
490 ///   lat( i_02 i_01 / tensor_2 + tensor_1 )
491 ///   lat( i_02 i_00 / tensor_2 + tensor_0 )
492 ///   lat( i_01 i_00 / tensor_1 + tensor_0 )
493 ///   lat( i_02 / tensor_2 )
494 ///   lat( i_01 / tensor_1 )
495 ///   lat( i_00 / tensor_0 )
496 /// }
497 #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2)                               \
498   TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) {                           \
499     auto em = DISJ1##Expr(t0, t1);                                             \
500     auto e = DISJ2##Expr(em, t2);                                              \
501     auto p0 = tensorPattern(t0);                                               \
502     auto p1 = tensorPattern(t1);                                               \
503     auto p2 = tensorPattern(t2);                                               \
504     auto s = merger.buildLattices(e, l0);                                      \
505                                                                                \
506     expectNumLatPoints(s, 7);                                                  \
507     expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2),      \
508                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
509     expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2),            \
510                               loopsToBits({{l0, t1}, {l0, t2}}));              \
511     expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2),            \
512                               loopsToBits({{l0, t0}, {l0, t2}}));              \
513     expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1),            \
514                               loopsToBits({{l0, t0}, {l0, t1}}));              \
515     expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}}));      \
516     expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}}));      \
517     expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}}));      \
518                                                                                \
519     s = merger.optimizeSet(s);                                                 \
520     expectNumLatPoints(s, 7);                                                  \
521     expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2),      \
522                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
523     expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2),            \
524                               loopsToBits({{l0, t1}, {l0, t2}}));              \
525     expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2),            \
526                               loopsToBits({{l0, t0}, {l0, t2}}));              \
527     expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1),            \
528                               loopsToBits({{l0, t0}, {l0, t1}}));              \
529     expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}}));      \
530     expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}}));      \
531     expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}}));      \
532   }
533 
534 FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
535 
536 #undef IMPL_MERGER_TEST_DISJ_DISJ
537 
538 /// Vector multiplication (conjunction) then multiplication (conjunction), i.e.;
539 ///   a(i) = b(i) * c(i) * d(i);
540 /// which should form
541 /// {
542 ///    lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
543 /// }
544 #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2)                               \
545   TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) {                           \
546     auto em = CONJ1##Expr(t0, t1);                                             \
547     auto e = CONJ2##Expr(em, t2);                                              \
548     auto p0 = tensorPattern(t0);                                               \
549     auto p1 = tensorPattern(t1);                                               \
550     auto p2 = tensorPattern(t2);                                               \
551     auto s = merger.buildLattices(e, l0);                                      \
552     expectNumLatPoints(s, 1);                                                  \
553     expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
554                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
555     s = merger.optimizeSet(s);                                                 \
556     expectNumLatPoints(s, 1);                                                  \
557     expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
558                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true);         \
559   }
560 
561 FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
562 
563 #undef IMPL_MERGER_TEST_CONJ_CONJ
564 
565 /// Vector addition (disjunction) of 2 vectors, i.e.;
566 ///   a(i) = b(i) + c(i)
567 /// which should form the 3 lattice points
568 /// {
569 ///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) )
570 ///   lat( i_00 / sparse_tensor_0 )
571 ///   lat( i_01 / dense_tensor_1 )
572 /// }
573 /// which should be optimized to
574 /// {
575 ///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton)
576 ///   lat( i_01 / dense_tensor_0 ) (no sparse dimension)
577 /// }
578 ///
579 /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
580 /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
581 #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP)                                    \
582   TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
583     auto e = OP##Expr(tensor(t0), tensor(t1));                                 \
584     auto p0 = tensorPattern(t0);                                               \
585     auto p1 = tensorPattern(t1);                                               \
586     auto s = merger.buildLattices(e, l0);                                      \
587                                                                                \
588     expectNumLatPoints(s, 3);                                                  \
589     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
590                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
591     expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}));      \
592     expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}));      \
593                                                                                \
594     s = merger.optimizeSet(s);                                                 \
595     expectNumLatPoints(s, 2);                                                  \
596     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
597                    loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
598     expectLatPoint(s, lat(1), p1, loopsToBits({{l0, t1}}), true);              \
599   }
600 
601 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
602 
603 #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
604 
605 /// Vector multiplication (conjunction) of 2 vectors, i.e.:
606 ///   a(i) = b(i) * c(i)
607 /// which should form the single lattice point
608 /// {
609 ///   lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) )
610 /// }
611 /// it should be optimized to
612 /// {
613 ///   lat( i_00 / (sparse_tensor_0 * dense_tensor_1) )
614 /// }
615 /// since i_01 is a dense dimension.
616 #define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP)                                    \
617   TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
618     auto e = OP##Expr(t0, t1);                                                 \
619     auto p0 = tensorPattern(t0);                                               \
620     auto p1 = tensorPattern(t1);                                               \
621     auto s = merger.buildLattices(e, l0);                                      \
622                                                                                \
623     expectNumLatPoints(s, 1);                                                  \
624     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
625                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
626                                                                                \
627     s = merger.optimizeSet(s);                                                 \
628     expectNumLatPoints(s, 1);                                                  \
629     expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t0}}),    \
630                    true);                                                      \
631   }
632 
633 FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
634 
635 #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
636 
// TODO: multi-dimensional tests
638 
// Restore the warning state saved by the push above.
640 #if defined(_MSC_VER) && !defined(__clang__)
641 #pragma warning(pop)
642 #endif
643