1 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
2 #include "llvm/Support/Compiler.h"
3 #include "gmock/gmock.h"
4 #include "gtest/gtest.h"
5 #include <memory>
6 
7 using namespace mlir;
8 using namespace mlir::sparse_tensor;
9 
10 namespace {
11 
12 ///
13 /// Defines macros to iterate binary and the combination of binary operations.
14 ///
15 
/// X-macro over the binary operations used by the helpers in this file.
/// Invokes DO(name, kind) once per entry, where `name` is the lowercase prefix
/// used to form helper names (e.g. name##Pattern, name##Expr) and `kind` is
/// the matching Kind enumerator.
#define FOREVERY_BINOP(DO)                                                     \
  DO(mulf, Kind::kMulF)                                                        \
  DO(mulc, Kind::kMulC)                                                        \
  DO(muli, Kind::kMulI)                                                        \
  DO(addf, Kind::kAddF)                                                        \
  DO(addc, Kind::kAddC)                                                        \
  DO(addi, Kind::kAddI)                                                        \
  DO(subf, Kind::kSubF)                                                        \
  DO(subc, Kind::kSubC)                                                        \
  DO(subi, Kind::kSubI)                                                        \
  DO(andi, Kind::kAndI)                                                        \
  DO(xori, Kind::kXorI)                                                        \
  DO(ori, Kind::kOrI)
29 
// Invokes TEST(op, <extra args>...) once for each commonly handled disjunctive
// binary operation.
// TODO: Disjunctive binary operations that need special handling are not
// included; e.g., division is not tested (for now) as it needs a constant
// non-zero divisor.
// `, ##__VA_ARGS__` drops the preceding comma when __VA_ARGS__ is empty, so
// the macro can be used with or without extra arguments.
#define FOREVERY_COMMON_DISJ_BINOP(TEST, ...)                                  \
  TEST(addf, ##__VA_ARGS__)                                                    \
  TEST(addc, ##__VA_ARGS__)                                                    \
  TEST(addi, ##__VA_ARGS__)                                                    \
  TEST(xori, ##__VA_ARGS__)                                                    \
  TEST(ori, ##__VA_ARGS__)
40 
// Invokes TEST(op, <extra args>...) once for each commonly handled conjunctive
// binary operation.
// TODO: Conjunctive binary operations that need special handling are not
// included; e.g., subtraction yields a different pattern as it is mapped to a
// negate operation.
#define FOREVERY_COMMON_CONJ_BINOP(TEST, ...)                                  \
  TEST(mulf, ##__VA_ARGS__)                                                    \
  TEST(mulc, ##__VA_ARGS__)                                                    \
  TEST(muli, ##__VA_ARGS__)                                                    \
  TEST(andi, ##__VA_ARGS__)
49 
// Invokes TEST(conj, disj) for every pairing of a common conjunctive operation
// (first argument) with a common disjunctive operation (second argument).
#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST)                          \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addf)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addc)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addi)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, xori)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, ori)
56 
// Invokes TEST(conj1, conj2) for every ordered pair of common conjunctive
// operations; the operation named on each line below is the second argument.
#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST)                          \
  FOREVERY_COMMON_CONJ_BINOP(TEST, mulf)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, mulc)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, muli)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, andi)
62 
// Invokes TEST(disj1, disj2) for every ordered pair of common disjunctive
// operations; the operation named on each line below is the second argument.
#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(TEST)                          \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addf)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addc)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addi)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, ori)                                        \
  FOREVERY_COMMON_DISJ_BINOP(TEST, xori)
69 
70 ///
71 /// Helper classes/functions for testing Merger.
72 ///
73 
74 /// Simple recursive data structure used to match expressions in Mergers.
75 struct Pattern {
76   Kind kind;
77 
78   /// Expressions representing tensors simply have a tensor number.
79   unsigned tensorNum;
80 
81   /// Tensor operations point to their children.
82   std::shared_ptr<Pattern> e0;
83   std::shared_ptr<Pattern> e1;
84 
85   /// Constructors.
86   /// Rather than using these, please use the readable helper constructor
87   /// functions below to make tests more readable.
88   Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
89   Pattern(Kind kind, const std::shared_ptr<Pattern> &e0,
90           const std::shared_ptr<Pattern> &e1)
91       : kind(kind), e0(e0), e1(e1) {
92     assert(kind >= Kind::kMulF);
93     assert(e0 && e1);
94   }
95 };
96 
97 ///
98 /// Readable Pattern builder functions.
99 /// These should be preferred over the actual constructors.
100 ///
101 
102 static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
103   return std::make_shared<Pattern>(tensorNum);
104 }
105 
/// Generates one builder per FOREVERY_BINOP entry: OP##Pattern(e0, e1) (e.g.
/// addfPattern) returns a Pattern node of kind KIND with children e0 and e1.
/// The builders are marked LLVM_ATTRIBUTE_UNUSED since not all of them are
/// referenced by the tests in this file (e.g. the sub* variants).
#define IMPL_BINOP_PATTERN(OP, KIND)                                           \
  LLVM_ATTRIBUTE_UNUSED static std::shared_ptr<Pattern> OP##Pattern(           \
      const std::shared_ptr<Pattern> &e0,                                      \
      const std::shared_ptr<Pattern> &e1) {                                    \
    return std::make_shared<Pattern>(KIND, e0, e1);                            \
  }

FOREVERY_BINOP(IMPL_BINOP_PATTERN)

#undef IMPL_BINOP_PATTERN
116 
117 class MergerTestBase : public ::testing::Test {
118 protected:
119   MergerTestBase(unsigned numTensors, unsigned numLoops)
120       : numTensors(numTensors), numLoops(numLoops),
121         merger(numTensors, numLoops) {}
122 
123   ///
124   /// Expression construction helpers.
125   ///
126 
127   unsigned tensor(unsigned tensor) {
128     return merger.addExp(Kind::kTensor, tensor);
129   }
130 
131 #define IMPL_BINOP_EXPR(OP, KIND)                                              \
132   LLVM_ATTRIBUTE_UNUSED unsigned OP##Expr(unsigned e0, unsigned e1) {          \
133     return merger.addExp(KIND, e0, e1);                                        \
134   }
135 
136   FOREVERY_BINOP(IMPL_BINOP_EXPR)
137 
138 #undef IMPL_BINOP_EXPR
139 
140   ///
141   /// Comparison helpers.
142   ///
143 
144   /// For readability of tests.
145   unsigned lat(unsigned lat) { return lat; }
146 
147   /// Returns true if a lattice point with an expression matching the given
148   /// pattern and bits matching the given bits is present in lattice points
149   /// [p, p+n) of lattice set s. This is useful for testing partial ordering
150   /// constraints between lattice points. We generally know how contiguous
151   /// groups of lattice points should be ordered with respect to other groups,
152   /// but there is no required ordering within groups.
153   /// If simple is true, then compare the lat.simple field instead to test the
154   /// result after optimization
155   bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
156                            const std::shared_ptr<Pattern> &pattern,
157                            const BitVector &bits, bool simple) {
158     for (unsigned i = p; i < p + n; ++i) {
159       if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
160           compareBits(s, i, bits, simple))
161         return true;
162     }
163     return false;
164   }
165 
166   /// Wrapper over latPointWithinRange for readability of tests.
167   void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
168                                  const std::shared_ptr<Pattern> &pattern,
169                                  const BitVector &bits, bool simple = false) {
170     EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits, simple));
171   }
172 
173   /// Wrapper over expectLatPointWithinRange for a single lat point.
174   void expectLatPoint(unsigned s, unsigned p,
175                       const std::shared_ptr<Pattern> &pattern,
176                       const BitVector &bits, bool simple = false) {
177     EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits, simple));
178   }
179 
180   /// Converts a vector of (loop, tensor) pairs to a bitvector with the
181   /// corresponding bits set.
182   BitVector
183   loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) {
184     BitVector testBits = BitVector(numTensors + 1, false);
185     for (auto l : loops) {
186       auto loop = std::get<0>(l);
187       auto tensor = std::get<1>(l);
188       testBits.set(numTensors * loop + tensor);
189     }
190     return testBits;
191   }
192 
193   /// Returns true if the bits of lattice point p in set s match the given bits.
194   /// If simple is true, then compare the lat.simple field instead to test the
195   /// result after optimization
196   bool compareBits(unsigned s, unsigned p, const BitVector &bits, bool simple) {
197     if (simple)
198       return merger.lat(merger.set(s)[p]).simple == bits;
199     return merger.lat(merger.set(s)[p]).bits == bits;
200   }
201 
202   /// Check that there are n lattice points in set s.
203   void expectNumLatPoints(unsigned s, unsigned n) {
204     EXPECT_THAT(merger.set(s).size(), n);
205   }
206 
207   /// Compares expressions for equality. Equality is defined recursively as:
208   /// - Operations are equal if they have the same kind and children.
209   /// - Leaf tensors are equal if they refer to the same tensor.
210   bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
211     auto tensorExp = merger.exp(e);
212     if (tensorExp.kind != pattern->kind)
213       return false;
214     switch (tensorExp.kind) {
215     // Leaf.
216     case kTensor:
217       return tensorExp.tensor == pattern->tensorNum;
218     case kInvariant:
219     case kIndex:
220       llvm_unreachable("invariant not handled yet");
221     // Unary operations.
222     case kAbsF:
223     case kAbsC:
224     case kCeilF:
225     case kFloorF:
226     case kSqrtF:
227     case kSqrtC:
228     case kExpm1F:
229     case kExpm1C:
230     case kLog1pF:
231     case kLog1pC:
232     case kSinF:
233     case kSinC:
234     case kTanhF:
235     case kTanhC:
236     case kNegF:
237     case kNegC:
238     case kNegI:
239     case kTruncF:
240     case kExtF:
241     case kCastFS:
242     case kCastFU:
243     case kCastSF:
244     case kCastUF:
245     case kCastS:
246     case kCastU:
247     case kCastIdx:
248     case kTruncI:
249     case kCIm:
250     case kCRe:
251     case kBitCast:
252     case kBinaryBranch:
253     case kUnary:
254     case kShlI:
255     case kBinary:
256       return compareExpression(tensorExp.children.e0, pattern->e0);
257     // Binary operations.
258     case kMulF:
259     case kMulC:
260     case kMulI:
261     case kDivF:
262     case kDivC:
263     case kDivS:
264     case kDivU:
265     case kAddF:
266     case kAddC:
267     case kAddI:
268     case kSubF:
269     case kSubC:
270     case kSubI:
271     case kAndI:
272     case kOrI:
273     case kXorI:
274     case kShrS:
275     case kShrU:
276       return compareExpression(tensorExp.children.e0, pattern->e0) &&
277              compareExpression(tensorExp.children.e1, pattern->e1);
278     }
279     llvm_unreachable("unexpected kind");
280   }
281 
282   unsigned numTensors;
283   unsigned numLoops;
284   Merger merger;
285 };
286 
287 ///
288 /// Tests with all sparse inputs.
289 ///
290 
291 class MergerTest3T1L : public MergerTestBase {
292 protected:
293   // Our three tensors (two inputs, one output).
294   const unsigned t0 = 0, t1 = 1, t2 = 2;
295 
296   // Our single loop.
297   const unsigned l0 = 0;
298 
299   MergerTest3T1L() : MergerTestBase(3, 1) {
300     // Tensor 0: sparse input vector.
301     merger.addExp(Kind::kTensor, t0, -1u);
302     merger.setDim(t0, l0, Dim::kSparse);
303 
304     // Tensor 1: sparse input vector.
305     merger.addExp(Kind::kTensor, t1, -1u);
306     merger.setDim(t1, l0, Dim::kSparse);
307 
308     // Tensor 2: dense output vector.
309     merger.addExp(Kind::kTensor, t2, -1u);
310     merger.setDim(t2, l0, Dim::kDense);
311   }
312 };
313 
314 class MergerTest4T1L : public MergerTestBase {
315 protected:
316   // Our four tensors (three inputs, one output).
317   const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
318 
319   // Our single loop.
320   const unsigned l0 = 0;
321 
322   MergerTest4T1L() : MergerTestBase(4, 1) {
323     // Tensor 0: sparse input vector.
324     merger.addExp(Kind::kTensor, t0, -1u);
325     merger.setDim(t0, l0, Dim::kSparse);
326 
327     // Tensor 1: sparse input vector.
328     merger.addExp(Kind::kTensor, t1, -1u);
329     merger.setDim(t1, l0, Dim::kSparse);
330 
331     // Tensor 2: sparse input vector
332     merger.addExp(Kind::kTensor, t2, -1u);
333     merger.setDim(t2, l0, Dim::kSparse);
334 
335     // Tensor 3: dense output vector
336     merger.addExp(Kind::kTensor, t3, -1u);
337     merger.setDim(t3, l0, Dim::kDense);
338   }
339 };
340 
341 ///
342 /// Tests with both sparse and dense input.
343 ///
344 
345 class MergerTest3T1LD : public MergerTestBase {
346 protected:
347   // Our three tensors (two inputs, one output).
348   const unsigned t0 = 0, t1 = 1, t2 = 2;
349 
350   // Our single loop.
351   const unsigned l0 = 0;
352 
353   MergerTest3T1LD() : MergerTestBase(3, 1) {
354     // Tensor 0: sparse input vector.
355     merger.addExp(Kind::kTensor, t0, -1u);
356     merger.setDim(t0, l0, Dim::kSparse);
357 
358     // Tensor 1: dense input vector.
359     merger.addExp(Kind::kTensor, t1, -1u);
360     merger.setDim(t1, l0, Dim::kDense);
361 
362     // Tensor 2: dense output vector.
363     merger.addExp(Kind::kTensor, t2, -1u);
364     merger.setDim(t2, l0, Dim::kDense);
365   }
366 };
367 
368 } // namespace
369 
370 /// Vector addition (disjunction) of 2 vectors. i.e.;
371 ///   a(i) = b(i) + c(i)
372 /// which should form the 3 lattice points
373 /// {
374 ///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
375 ///   lat( i_00 / tensor_0 )
376 ///   lat( i_01 / tensor_1 )
377 /// }
378 /// and after optimization, the lattice points do not change (as there is no
379 /// duplicated point and all input vectors are sparse vector).
380 /// {
381 ///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
382 ///   lat( i_00 / tensor_0 )
383 ///   lat( i_01 / tensor_1 )
384 /// }
385 #define IMPL_MERGER_TEST_DISJ(OP)                                              \
386   TEST_F(MergerTest3T1L, vector_##OP) {                                        \
387     auto e = OP##Expr(tensor(t0), tensor(t1));                                 \
388     auto p0 = tensorPattern(t0);                                               \
389     auto p1 = tensorPattern(t1);                                               \
390     auto s = merger.buildLattices(e, l0);                                      \
391                                                                                \
392     expectNumLatPoints(s, 3);                                                  \
393     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
394                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
395     expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}));      \
396     expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}));      \
397                                                                                \
398     s = merger.optimizeSet(s);                                                 \
399     expectNumLatPoints(s, 3);                                                  \
400     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
401                    loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
402     expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}),       \
403                               true);                                           \
404     expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}),       \
405                               true);                                           \
406   }
407 
408 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
409 
410 #undef IMPL_MERGER_TEST_DISJ
411 
412 /// Vector multiplication (conjunction) of 2 vectors, i.e.;
413 ///   a(i) = b(i) * c(i)
414 /// which should form the single lattice point
415 /// {
416 ///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
417 /// }
418 #define IMPL_MERGER_TEST_CONJ(OP)                                              \
419   TEST_F(MergerTest3T1L, vector_##OP) {                                        \
420     auto e = OP##Expr(t0, t1);                                                 \
421     auto p0 = tensorPattern(t0);                                               \
422     auto p1 = tensorPattern(t1);                                               \
423     auto s = merger.buildLattices(e, l0);                                      \
424                                                                                \
425     expectNumLatPoints(s, 1);                                                  \
426     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
427                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
428                                                                                \
429     s = merger.optimizeSet(s);                                                 \
430     expectNumLatPoints(s, 1);                                                  \
431     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
432                    loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
433   }
434 
435 FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
436 
437 #undef IMPL_MERGER_TEST_CONJ
438 
439 /// Vector multiplication (conjunction) then addition (disjunction), i.e.;
440 ///   a(i) = b(i) * c(i) + d(i);
441 /// which should form
442 /// {
443 ///    lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 )
444 ///    lat( i_00 i_01 / tensor_0 * tensor_1
445 ///    lat( i_02 / tensor_2 )
446 /// }
447 #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ)                                 \
448   TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) {                             \
449     auto em = CONJ##Expr(t0, t1);                                              \
450     auto e = DISJ##Expr(em, t2);                                               \
451     auto p0 = tensorPattern(t0);                                               \
452     auto p1 = tensorPattern(t1);                                               \
453     auto p2 = tensorPattern(t2);                                               \
454     auto s = merger.buildLattices(e, l0);                                      \
455                                                                                \
456     expectNumLatPoints(s, 3);                                                  \
457     expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2),        \
458                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
459     expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1),             \
460                               loopsToBits({{l0, t0}, {l0, t1}}));              \
461     expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}}));      \
462                                                                                \
463     s = merger.optimizeSet(s);                                                 \
464     expectNumLatPoints(s, 3);                                                  \
465     expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2),        \
466                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
467     expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1),             \
468                               loopsToBits({{l0, t0}, {l0, t1}}));              \
469     expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}}));      \
470   }
471 
472 FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
473 
474 #undef IMPL_MERGER_TEST_CONJ_DISJ
475 
476 /// Vector addition (disjunction) then addition (disjunction), i.e.;
477 ///   a(i) = b(i) + c(i) + d(i)
478 /// which should form
479 /// {
480 ///   lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 )
481 ///   lat( i_02 i_01 / tensor_2 + tensor_1 )
482 ///   lat( i_02 i_00 / tensor_2 + tensor_0 )
483 ///   lat( i_01 i_00 / tensor_1 + tensor_0 )
484 ///   lat( i_02 / tensor_2 )
485 ///   lat( i_01 / tensor_1 )
486 ///   lat( i_00 / tensor_0 )
487 /// }
488 #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2)                               \
489   TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) {                           \
490     auto em = DISJ1##Expr(t0, t1);                                             \
491     auto e = DISJ2##Expr(em, t2);                                              \
492     auto p0 = tensorPattern(t0);                                               \
493     auto p1 = tensorPattern(t1);                                               \
494     auto p2 = tensorPattern(t2);                                               \
495     auto s = merger.buildLattices(e, l0);                                      \
496                                                                                \
497     expectNumLatPoints(s, 7);                                                  \
498     expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2),      \
499                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
500     expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2),            \
501                               loopsToBits({{l0, t1}, {l0, t2}}));              \
502     expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2),            \
503                               loopsToBits({{l0, t0}, {l0, t2}}));              \
504     expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1),            \
505                               loopsToBits({{l0, t0}, {l0, t1}}));              \
506     expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}}));      \
507     expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}}));      \
508     expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}}));      \
509                                                                                \
510     s = merger.optimizeSet(s);                                                 \
511     expectNumLatPoints(s, 7);                                                  \
512     expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2),      \
513                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
514     expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2),            \
515                               loopsToBits({{l0, t1}, {l0, t2}}));              \
516     expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2),            \
517                               loopsToBits({{l0, t0}, {l0, t2}}));              \
518     expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1),            \
519                               loopsToBits({{l0, t0}, {l0, t1}}));              \
520     expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}}));      \
521     expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}}));      \
522     expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}}));      \
523   }
524 
525 FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
526 
527 #undef IMPL_MERGER_TEST_DISJ_DISJ
528 
529 /// Vector multiplication (conjunction) then multiplication (conjunction), i.e.;
530 ///   a(i) = b(i) * c(i) * d(i);
531 /// which should form
532 /// {
533 ///    lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
534 /// }
535 #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2)                               \
536   TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) {                           \
537     auto em = CONJ1##Expr(t0, t1);                                             \
538     auto e = CONJ2##Expr(em, t2);                                              \
539     auto p0 = tensorPattern(t0);                                               \
540     auto p1 = tensorPattern(t1);                                               \
541     auto p2 = tensorPattern(t2);                                               \
542     auto s = merger.buildLattices(e, l0);                                      \
543     expectNumLatPoints(s, 1);                                                  \
544     expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
545                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}));               \
546     s = merger.optimizeSet(s);                                                 \
547     expectNumLatPoints(s, 1);                                                  \
548     expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2),      \
549                    loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true);         \
550   }
551 
552 FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
553 
554 #undef IMPL_MERGER_TEST_CONJ_CONJ
555 
556 /// Vector addition (disjunction) of 2 vectors, i.e.;
557 ///   a(i) = b(i) + c(i)
558 /// which should form the 3 lattice points
559 /// {
560 ///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) )
561 ///   lat( i_00 / sparse_tensor_0 )
562 ///   lat( i_01 / dense_tensor_1 )
563 /// }
564 /// which should be optimized to
565 /// {
566 ///   lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton)
567 ///   lat( i_01 / dense_tensor_0 ) (no sparse dimension)
568 /// }
569 ///
570 /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
571 /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
572 #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP)                                    \
573   TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
574     auto e = OP##Expr(tensor(t0), tensor(t1));                                 \
575     auto p0 = tensorPattern(t0);                                               \
576     auto p1 = tensorPattern(t1);                                               \
577     auto s = merger.buildLattices(e, l0);                                      \
578                                                                                \
579     expectNumLatPoints(s, 3);                                                  \
580     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
581                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
582     expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}));      \
583     expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}));      \
584                                                                                \
585     s = merger.optimizeSet(s);                                                 \
586     expectNumLatPoints(s, 2);                                                  \
587     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
588                    loopsToBits({{l0, t0}, {l0, t1}}), true);                   \
589     expectLatPoint(s, lat(1), p1, loopsToBits({{l0, t1}}), true);              \
590   }
591 
592 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
593 
594 #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
595 
596 /// Vector multiplication (conjunction) of 2 vectors, i.e.:
597 ///   a(i) = b(i) * c(i)
598 /// which should form the single lattice point
599 /// {
600 ///   lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) )
601 /// }
602 /// it should be optimized to
603 /// {
604 ///   lat( i_00 / (sparse_tensor_0 * dense_tensor_1) )
605 /// }
606 /// since i_01 is a dense dimension.
607 #define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP)                                    \
608   TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
609     auto e = OP##Expr(t0, t1);                                                 \
610     auto p0 = tensorPattern(t0);                                               \
611     auto p1 = tensorPattern(t1);                                               \
612     auto s = merger.buildLattices(e, l0);                                      \
613                                                                                \
614     expectNumLatPoints(s, 1);                                                  \
615     expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
616                    loopsToBits({{l0, t0}, {l0, t1}}));                         \
617                                                                                \
618     s = merger.optimizeSet(s);                                                 \
619     expectNumLatPoints(s, 1);                                                  \
620     expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t0}}),    \
621                    true);                                                      \
622   }
623 
624 FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
625 
626 #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
627 
// TODO: multi-dim tests
629