1 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
2 #include "llvm/Support/Compiler.h"
3 #include "gmock/gmock.h"
4 #include "gtest/gtest.h"
5 #include <memory>
6
7 using namespace mlir;
8 using namespace mlir::sparse_tensor;
9
// Silence "warning C4002: 'too many arguments for function-like macro
// invocation'" as MSVC handles ##__VA_ARGS__ differently than gcc/clang.
13
14 #if defined(_MSC_VER) && !defined(__clang__)
15 #pragma warning(push)
16 #pragma warning(disable : 4002)
17 #endif
18
19 namespace {
20
///
/// Defines macros to iterate binary and the combination of binary operations.
///

/// Invokes DO(<op-name>, <Kind>) once for every binary operation under test.
#define FOREVERY_BINOP(DO)                                                     \
  DO(mulf, Kind::kMulF)                                                        \
  DO(mulc, Kind::kMulC)                                                        \
  DO(muli, Kind::kMulI)                                                        \
  DO(addf, Kind::kAddF)                                                        \
  DO(addc, Kind::kAddC)                                                        \
  DO(addi, Kind::kAddI)                                                        \
  DO(subf, Kind::kSubF)                                                        \
  DO(subc, Kind::kSubC)                                                        \
  DO(subi, Kind::kSubI)                                                        \
  DO(andi, Kind::kAndI)                                                        \
  DO(xori, Kind::kXorI)                                                        \
  DO(ori, Kind::kOrI)

// TODO: Disjunctive binary operations that need special handling are not
// included, e.g., division is not tested (for now) as it needs a constant
// non-zero divisor.
// ##__VA_ARGS__ handles cases when __VA_ARGS__ is empty.
#define FOREVERY_COMMON_DISJ_BINOP(TEST, ...)                                  \
  TEST(addf, ##__VA_ARGS__)                                                    \
  TEST(addc, ##__VA_ARGS__)                                                    \
  TEST(addi, ##__VA_ARGS__)                                                    \
  TEST(xori, ##__VA_ARGS__)                                                    \
  TEST(ori, ##__VA_ARGS__)

// TODO: Conjunctive binary operations that need special handling are not
// included, e.g., subtraction yields a different pattern as it is mapped to
// a negate operation.
#define FOREVERY_COMMON_CONJ_BINOP(TEST, ...)                                  \
  TEST(mulf, ##__VA_ARGS__)                                                    \
  TEST(mulc, ##__VA_ARGS__)                                                    \
  TEST(muli, ##__VA_ARGS__)                                                    \
  TEST(andi, ##__VA_ARGS__)

/// Pairs every common conjunctive op with every common disjunctive op.
#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST)                          \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addf)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addc)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addi)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, xori)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, ori)

/// Pairs every common conjunctive op with every common conjunctive op.
#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST)                          \
  FOREVERY_COMMON_CONJ_BINOP(TEST, mulf)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, mulc)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, muli)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, andi)

/// Pairs every common disjunctive op with every common disjunctive op.
#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(TEST)                          \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addf)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addc)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addi)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, ori)                                        \
  FOREVERY_COMMON_DISJ_BINOP(TEST, xori)
78
79 ///
80 /// Helper classes/functions for testing Merger.
81 ///
82
83 /// Simple recursive data structure used to match expressions in Mergers.
84 struct Pattern {
85 Kind kind;
86
87 /// Expressions representing tensors simply have a tensor number.
88 unsigned tensorNum;
89
90 /// Tensor operations point to their children.
91 std::shared_ptr<Pattern> e0;
92 std::shared_ptr<Pattern> e1;
93
94 /// Constructors.
95 /// Rather than using these, please use the readable helper constructor
96 /// functions below to make tests more readable.
Pattern__anonb1ca32140111::Pattern97 Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
Pattern__anonb1ca32140111::Pattern98 Pattern(Kind kind, const std::shared_ptr<Pattern> &e0,
99 const std::shared_ptr<Pattern> &e1)
100 : kind(kind), e0(e0), e1(e1) {
101 assert(kind >= Kind::kMulF);
102 assert(e0 && e1);
103 }
104 };
105
106 ///
107 /// Readable Pattern builder functions.
108 /// These should be preferred over the actual constructors.
109 ///
110
tensorPattern(unsigned tensorNum)111 static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
112 return std::make_shared<Pattern>(tensorNum);
113 }
114
115 #define IMPL_BINOP_PATTERN(OP, KIND) \
116 LLVM_ATTRIBUTE_UNUSED static std::shared_ptr<Pattern> OP##Pattern( \
117 const std::shared_ptr<Pattern> &e0, \
118 const std::shared_ptr<Pattern> &e1) { \
119 return std::make_shared<Pattern>(KIND, e0, e1); \
120 }
121
122 FOREVERY_BINOP(IMPL_BINOP_PATTERN)
123
124 #undef IMPL_BINOP_PATTERN
125
126 class MergerTestBase : public ::testing::Test {
127 protected:
MergerTestBase(unsigned numTensors,unsigned numLoops)128 MergerTestBase(unsigned numTensors, unsigned numLoops)
129 : numTensors(numTensors), numLoops(numLoops),
130 merger(numTensors, numLoops) {}
131
132 ///
133 /// Expression construction helpers.
134 ///
135
tensor(unsigned tensor)136 unsigned tensor(unsigned tensor) {
137 return merger.addExp(Kind::kTensor, tensor);
138 }
139
140 #define IMPL_BINOP_EXPR(OP, KIND) \
141 LLVM_ATTRIBUTE_UNUSED unsigned OP##Expr(unsigned e0, unsigned e1) { \
142 return merger.addExp(KIND, e0, e1); \
143 }
144
FOREVERY_BINOP(IMPL_BINOP_EXPR)145 FOREVERY_BINOP(IMPL_BINOP_EXPR)
146
147 #undef IMPL_BINOP_EXPR
148
149 ///
150 /// Comparison helpers.
151 ///
152
153 /// For readability of tests.
154 unsigned lat(unsigned lat) { return lat; }
155
156 /// Returns true if a lattice point with an expression matching the given
157 /// pattern and bits matching the given bits is present in lattice points
158 /// [p, p+n) of lattice set s. This is useful for testing partial ordering
159 /// constraints between lattice points. We generally know how contiguous
160 /// groups of lattice points should be ordered with respect to other groups,
161 /// but there is no required ordering within groups.
162 /// If simple is true, then compare the lat.simple field instead to test the
163 /// result after optimization
latPointWithinRange(unsigned s,unsigned p,unsigned n,const std::shared_ptr<Pattern> & pattern,const BitVector & bits,bool simple)164 bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
165 const std::shared_ptr<Pattern> &pattern,
166 const BitVector &bits, bool simple) {
167 for (unsigned i = p; i < p + n; ++i) {
168 if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
169 compareBits(s, i, bits, simple))
170 return true;
171 }
172 return false;
173 }
174
175 /// Wrapper over latPointWithinRange for readability of tests.
expectLatPointWithinRange(unsigned s,unsigned p,unsigned n,const std::shared_ptr<Pattern> & pattern,const BitVector & bits,bool simple=false)176 void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
177 const std::shared_ptr<Pattern> &pattern,
178 const BitVector &bits, bool simple = false) {
179 EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits, simple));
180 }
181
182 /// Wrapper over expectLatPointWithinRange for a single lat point.
expectLatPoint(unsigned s,unsigned p,const std::shared_ptr<Pattern> & pattern,const BitVector & bits,bool simple=false)183 void expectLatPoint(unsigned s, unsigned p,
184 const std::shared_ptr<Pattern> &pattern,
185 const BitVector &bits, bool simple = false) {
186 EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits, simple));
187 }
188
189 /// Converts a vector of (loop, tensor) pairs to a bitvector with the
190 /// corresponding bits set.
191 BitVector
loopsToBits(const std::vector<std::pair<unsigned,unsigned>> & loops)192 loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) {
193 BitVector testBits = BitVector(numTensors + 1, false);
194 for (auto l : loops) {
195 auto loop = std::get<0>(l);
196 auto tensor = std::get<1>(l);
197 testBits.set(numTensors * loop + tensor);
198 }
199 return testBits;
200 }
201
202 /// Returns true if the bits of lattice point p in set s match the given bits.
203 /// If simple is true, then compare the lat.simple field instead to test the
204 /// result after optimization
compareBits(unsigned s,unsigned p,const BitVector & bits,bool simple)205 bool compareBits(unsigned s, unsigned p, const BitVector &bits, bool simple) {
206 if (simple)
207 return merger.lat(merger.set(s)[p]).simple == bits;
208 return merger.lat(merger.set(s)[p]).bits == bits;
209 }
210
211 /// Check that there are n lattice points in set s.
expectNumLatPoints(unsigned s,unsigned n)212 void expectNumLatPoints(unsigned s, unsigned n) {
213 EXPECT_THAT(merger.set(s).size(), n);
214 }
215
216 /// Compares expressions for equality. Equality is defined recursively as:
217 /// - Operations are equal if they have the same kind and children.
218 /// - Leaf tensors are equal if they refer to the same tensor.
compareExpression(unsigned e,const std::shared_ptr<Pattern> & pattern)219 bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
220 auto tensorExp = merger.exp(e);
221 if (tensorExp.kind != pattern->kind)
222 return false;
223 switch (tensorExp.kind) {
224 // Leaf.
225 case kTensor:
226 return tensorExp.tensor == pattern->tensorNum;
227 case kInvariant:
228 case kIndex:
229 llvm_unreachable("invariant not handled yet");
230 // Unary operations.
231 case kAbsF:
232 case kAbsC:
233 case kCeilF:
234 case kFloorF:
235 case kSqrtF:
236 case kSqrtC:
237 case kExpm1F:
238 case kExpm1C:
239 case kLog1pF:
240 case kLog1pC:
241 case kSinF:
242 case kSinC:
243 case kTanhF:
244 case kTanhC:
245 case kNegF:
246 case kNegC:
247 case kNegI:
248 case kTruncF:
249 case kExtF:
250 case kCastFS:
251 case kCastFU:
252 case kCastSF:
253 case kCastUF:
254 case kCastS:
255 case kCastU:
256 case kCastIdx:
257 case kTruncI:
258 case kCIm:
259 case kCRe:
260 case kBitCast:
261 case kBinaryBranch:
262 case kUnary:
263 case kShlI:
264 case kBinary:
265 return compareExpression(tensorExp.children.e0, pattern->e0);
266 // Binary operations.
267 case kMulF:
268 case kMulC:
269 case kMulI:
270 case kDivF:
271 case kDivC:
272 case kDivS:
273 case kDivU:
274 case kAddF:
275 case kAddC:
276 case kAddI:
277 case kSubF:
278 case kSubC:
279 case kSubI:
280 case kAndI:
281 case kOrI:
282 case kXorI:
283 case kShrS:
284 case kShrU:
285 return compareExpression(tensorExp.children.e0, pattern->e0) &&
286 compareExpression(tensorExp.children.e1, pattern->e1);
287 }
288 llvm_unreachable("unexpected kind");
289 }
290
291 unsigned numTensors;
292 unsigned numLoops;
293 Merger merger;
294 };
295
296 ///
297 /// Tests with all sparse inputs.
298 ///
299
300 class MergerTest3T1L : public MergerTestBase {
301 protected:
302 // Our three tensors (two inputs, one output).
303 const unsigned t0 = 0, t1 = 1, t2 = 2;
304
305 // Our single loop.
306 const unsigned l0 = 0;
307
MergerTest3T1L()308 MergerTest3T1L() : MergerTestBase(3, 1) {
309 // Tensor 0: sparse input vector.
310 merger.addExp(Kind::kTensor, t0, -1u);
311 merger.setDim(t0, l0, Dim::kSparse);
312
313 // Tensor 1: sparse input vector.
314 merger.addExp(Kind::kTensor, t1, -1u);
315 merger.setDim(t1, l0, Dim::kSparse);
316
317 // Tensor 2: dense output vector.
318 merger.addExp(Kind::kTensor, t2, -1u);
319 merger.setDim(t2, l0, Dim::kDense);
320 }
321 };
322
323 class MergerTest4T1L : public MergerTestBase {
324 protected:
325 // Our four tensors (three inputs, one output).
326 const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
327
328 // Our single loop.
329 const unsigned l0 = 0;
330
MergerTest4T1L()331 MergerTest4T1L() : MergerTestBase(4, 1) {
332 // Tensor 0: sparse input vector.
333 merger.addExp(Kind::kTensor, t0, -1u);
334 merger.setDim(t0, l0, Dim::kSparse);
335
336 // Tensor 1: sparse input vector.
337 merger.addExp(Kind::kTensor, t1, -1u);
338 merger.setDim(t1, l0, Dim::kSparse);
339
340 // Tensor 2: sparse input vector
341 merger.addExp(Kind::kTensor, t2, -1u);
342 merger.setDim(t2, l0, Dim::kSparse);
343
344 // Tensor 3: dense output vector
345 merger.addExp(Kind::kTensor, t3, -1u);
346 merger.setDim(t3, l0, Dim::kDense);
347 }
348 };
349
350 ///
351 /// Tests with both sparse and dense input.
352 ///
353
354 class MergerTest3T1LD : public MergerTestBase {
355 protected:
356 // Our three tensors (two inputs, one output).
357 const unsigned t0 = 0, t1 = 1, t2 = 2;
358
359 // Our single loop.
360 const unsigned l0 = 0;
361
MergerTest3T1LD()362 MergerTest3T1LD() : MergerTestBase(3, 1) {
363 // Tensor 0: sparse input vector.
364 merger.addExp(Kind::kTensor, t0, -1u);
365 merger.setDim(t0, l0, Dim::kSparse);
366
367 // Tensor 1: dense input vector.
368 merger.addExp(Kind::kTensor, t1, -1u);
369 merger.setDim(t1, l0, Dim::kDense);
370
371 // Tensor 2: dense output vector.
372 merger.addExp(Kind::kTensor, t2, -1u);
373 merger.setDim(t2, l0, Dim::kDense);
374 }
375 };
376
377 } // namespace
378
379 /// Vector addition (disjunction) of 2 vectors. i.e.;
380 /// a(i) = b(i) + c(i)
381 /// which should form the 3 lattice points
382 /// {
383 /// lat( i_00 i_01 / (tensor_0 + tensor_1) )
384 /// lat( i_00 / tensor_0 )
385 /// lat( i_01 / tensor_1 )
386 /// }
387 /// and after optimization, the lattice points do not change (as there is no
388 /// duplicated point and all input vectors are sparse vector).
389 /// {
390 /// lat( i_00 i_01 / (tensor_0 + tensor_1) )
391 /// lat( i_00 / tensor_0 )
392 /// lat( i_01 / tensor_1 )
393 /// }
394 #define IMPL_MERGER_TEST_DISJ(OP) \
395 TEST_F(MergerTest3T1L, vector_##OP) { \
396 auto e = OP##Expr(tensor(t0), tensor(t1)); \
397 auto p0 = tensorPattern(t0); \
398 auto p1 = tensorPattern(t1); \
399 auto s = merger.buildLattices(e, l0); \
400 \
401 expectNumLatPoints(s, 3); \
402 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
403 loopsToBits({{l0, t0}, {l0, t1}})); \
404 expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \
405 expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \
406 \
407 s = merger.optimizeSet(s); \
408 expectNumLatPoints(s, 3); \
409 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
410 loopsToBits({{l0, t0}, {l0, t1}}), true); \
411 expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}), \
412 true); \
413 expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}), \
414 true); \
415 }
416
417 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
418
419 #undef IMPL_MERGER_TEST_DISJ
420
421 /// Vector multiplication (conjunction) of 2 vectors, i.e.;
422 /// a(i) = b(i) * c(i)
423 /// which should form the single lattice point
424 /// {
425 /// lat( i_00 i_01 / (tensor_0 * tensor_1) )
426 /// }
427 #define IMPL_MERGER_TEST_CONJ(OP) \
428 TEST_F(MergerTest3T1L, vector_##OP) { \
429 auto e = OP##Expr(t0, t1); \
430 auto p0 = tensorPattern(t0); \
431 auto p1 = tensorPattern(t1); \
432 auto s = merger.buildLattices(e, l0); \
433 \
434 expectNumLatPoints(s, 1); \
435 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
436 loopsToBits({{l0, t0}, {l0, t1}})); \
437 \
438 s = merger.optimizeSet(s); \
439 expectNumLatPoints(s, 1); \
440 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
441 loopsToBits({{l0, t0}, {l0, t1}}), true); \
442 }
443
444 FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
445
446 #undef IMPL_MERGER_TEST_CONJ
447
448 /// Vector multiplication (conjunction) then addition (disjunction), i.e.;
449 /// a(i) = b(i) * c(i) + d(i);
450 /// which should form
451 /// {
452 /// lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 )
453 /// lat( i_00 i_01 / tensor_0 * tensor_1
454 /// lat( i_02 / tensor_2 )
455 /// }
456 #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
457 TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
458 auto em = CONJ##Expr(t0, t1); \
459 auto e = DISJ##Expr(em, t2); \
460 auto p0 = tensorPattern(t0); \
461 auto p1 = tensorPattern(t1); \
462 auto p2 = tensorPattern(t2); \
463 auto s = merger.buildLattices(e, l0); \
464 \
465 expectNumLatPoints(s, 3); \
466 expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
467 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
468 expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \
469 loopsToBits({{l0, t0}, {l0, t1}})); \
470 expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \
471 \
472 s = merger.optimizeSet(s); \
473 expectNumLatPoints(s, 3); \
474 expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
475 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
476 expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \
477 loopsToBits({{l0, t0}, {l0, t1}})); \
478 expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \
479 }
480
481 FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
482
483 #undef IMPL_MERGER_TEST_CONJ_DISJ
484
485 /// Vector addition (disjunction) then addition (disjunction), i.e.;
486 /// a(i) = b(i) + c(i) + d(i)
487 /// which should form
488 /// {
489 /// lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 )
490 /// lat( i_02 i_01 / tensor_2 + tensor_1 )
491 /// lat( i_02 i_00 / tensor_2 + tensor_0 )
492 /// lat( i_01 i_00 / tensor_1 + tensor_0 )
493 /// lat( i_02 / tensor_2 )
494 /// lat( i_01 / tensor_1 )
495 /// lat( i_00 / tensor_0 )
496 /// }
497 #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
498 TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
499 auto em = DISJ1##Expr(t0, t1); \
500 auto e = DISJ2##Expr(em, t2); \
501 auto p0 = tensorPattern(t0); \
502 auto p1 = tensorPattern(t1); \
503 auto p2 = tensorPattern(t2); \
504 auto s = merger.buildLattices(e, l0); \
505 \
506 expectNumLatPoints(s, 7); \
507 expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
508 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
509 expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \
510 loopsToBits({{l0, t1}, {l0, t2}})); \
511 expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \
512 loopsToBits({{l0, t0}, {l0, t2}})); \
513 expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \
514 loopsToBits({{l0, t0}, {l0, t1}})); \
515 expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \
516 expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \
517 expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \
518 \
519 s = merger.optimizeSet(s); \
520 expectNumLatPoints(s, 7); \
521 expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
522 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
523 expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \
524 loopsToBits({{l0, t1}, {l0, t2}})); \
525 expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \
526 loopsToBits({{l0, t0}, {l0, t2}})); \
527 expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \
528 loopsToBits({{l0, t0}, {l0, t1}})); \
529 expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \
530 expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \
531 expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \
532 }
533
534 FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
535
536 #undef IMPL_MERGER_TEST_DISJ_DISJ
537
538 /// Vector multiplication (conjunction) then multiplication (conjunction), i.e.;
539 /// a(i) = b(i) * c(i) * d(i);
540 /// which should form
541 /// {
542 /// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
543 /// }
544 #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
545 TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
546 auto em = CONJ1##Expr(t0, t1); \
547 auto e = CONJ2##Expr(em, t2); \
548 auto p0 = tensorPattern(t0); \
549 auto p1 = tensorPattern(t1); \
550 auto p2 = tensorPattern(t2); \
551 auto s = merger.buildLattices(e, l0); \
552 expectNumLatPoints(s, 1); \
553 expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
554 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
555 s = merger.optimizeSet(s); \
556 expectNumLatPoints(s, 1); \
557 expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
558 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \
559 }
560
561 FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
562
563 #undef IMPL_MERGER_TEST_CONJ_CONJ
564
565 /// Vector addition (disjunction) of 2 vectors, i.e.;
566 /// a(i) = b(i) + c(i)
567 /// which should form the 3 lattice points
568 /// {
569 /// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) )
570 /// lat( i_00 / sparse_tensor_0 )
571 /// lat( i_01 / dense_tensor_1 )
572 /// }
573 /// which should be optimized to
574 /// {
575 /// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton)
576 /// lat( i_01 / dense_tensor_0 ) (no sparse dimension)
577 /// }
578 ///
579 /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
580 /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
581 #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP) \
582 TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
583 auto e = OP##Expr(tensor(t0), tensor(t1)); \
584 auto p0 = tensorPattern(t0); \
585 auto p1 = tensorPattern(t1); \
586 auto s = merger.buildLattices(e, l0); \
587 \
588 expectNumLatPoints(s, 3); \
589 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
590 loopsToBits({{l0, t0}, {l0, t1}})); \
591 expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \
592 expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \
593 \
594 s = merger.optimizeSet(s); \
595 expectNumLatPoints(s, 2); \
596 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
597 loopsToBits({{l0, t0}, {l0, t1}}), true); \
598 expectLatPoint(s, lat(1), p1, loopsToBits({{l0, t1}}), true); \
599 }
600
601 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
602
603 #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
604
605 /// Vector multiplication (conjunction) of 2 vectors, i.e.:
606 /// a(i) = b(i) * c(i)
607 /// which should form the single lattice point
608 /// {
609 /// lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) )
610 /// }
611 /// it should be optimized to
612 /// {
613 /// lat( i_00 / (sparse_tensor_0 * dense_tensor_1) )
614 /// }
615 /// since i_01 is a dense dimension.
616 #define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP) \
617 TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
618 auto e = OP##Expr(t0, t1); \
619 auto p0 = tensorPattern(t0); \
620 auto p1 = tensorPattern(t1); \
621 auto s = merger.buildLattices(e, l0); \
622 \
623 expectNumLatPoints(s, 1); \
624 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
625 loopsToBits({{l0, t0}, {l0, t1}})); \
626 \
627 s = merger.optimizeSet(s); \
628 expectNumLatPoints(s, 1); \
629 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t0}}), \
630 true); \
631 }
632
633 FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
634
635 #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
636
// TODO: multi-dimensional tests
638
639 // restore warning status
640 #if defined(_MSC_VER) && !defined(__clang__)
641 #pragma warning(pop)
642 #endif
643