1 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 2 #include "llvm/Support/Compiler.h" 3 #include "gmock/gmock.h" 4 #include "gtest/gtest.h" 5 #include <memory> 6 7 using namespace mlir; 8 using namespace mlir::sparse_tensor; 9 10 namespace { 11 12 /// 13 /// Defines macros to iterate binary and the combination of binary operations. 14 /// 15 16 #define FOREVERY_BINOP(DO) \ 17 DO(mulf, Kind::kMulF) \ 18 DO(mulc, Kind::kMulC) \ 19 DO(muli, Kind::kMulI) \ 20 DO(addf, Kind::kAddF) \ 21 DO(addc, Kind::kAddC) \ 22 DO(addi, Kind::kAddI) \ 23 DO(subf, Kind::kSubF) \ 24 DO(subc, Kind::kSubC) \ 25 DO(subi, Kind::kSubI) \ 26 DO(andi, Kind::kAndI) \ 27 DO(xori, Kind::kXorI) \ 28 DO(ori, Kind::kOrI) 29 30 // TODO: Disjunctive binary operations that need special handling are not 31 // included, e.g., Division are not tested (for now) as it need a constant 32 // non-zero dividend. 33 // ##__VA_ARGS__ handles cases when __VA_ARGS__ is empty. 34 #define FOREVERY_COMMON_DISJ_BINOP(TEST, ...) \ 35 TEST(addf, ##__VA_ARGS__) \ 36 TEST(addc, ##__VA_ARGS__) \ 37 TEST(addi, ##__VA_ARGS__) \ 38 TEST(xori, ##__VA_ARGS__) \ 39 TEST(ori, ##__VA_ARGS__) 40 41 // TODO: Conjunctive binary operations that need special handling are not 42 // included, e.g., substraction yields a different pattern as it is mapped to 43 // negate operation. 44 #define FOREVERY_COMMON_CONJ_BINOP(TEST, ...) \ 45 TEST(mulf, ##__VA_ARGS__) \ 46 TEST(mulc, ##__VA_ARGS__) \ 47 TEST(muli, ##__VA_ARGS__) \ 48 TEST(andi, ##__VA_ARGS__) 49 50 #define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST) \ 51 FOREVERY_COMMON_CONJ_BINOP(TEST, addf) \ 52 FOREVERY_COMMON_CONJ_BINOP(TEST, addc) \ 53 FOREVERY_COMMON_CONJ_BINOP(TEST, addi) \ 54 FOREVERY_COMMON_CONJ_BINOP(TEST, xori) \ 55 FOREVERY_COMMON_CONJ_BINOP(TEST, ori) 56 57 #define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST) \ 58 FOREVERY_COMMON_CONJ_BINOP(TEST, mulf) \ 59 FOREVERY_COMMON_CONJ_BINOP(TEST, mulc) \ 60 FOREVERY_COMMON_CONJ_BINOP(TEST, muli) \ 61 FOREVERY_COMMON_CONJ_BINOP(TEST, andi) 62 63 #define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(TEST) \ 64 FOREVERY_COMMON_DISJ_BINOP(TEST, addf) \ 65 FOREVERY_COMMON_DISJ_BINOP(TEST, addc) \ 66 FOREVERY_COMMON_DISJ_BINOP(TEST, addi) \ 67 FOREVERY_COMMON_DISJ_BINOP(TEST, ori) \ 68 FOREVERY_COMMON_DISJ_BINOP(TEST, xori) 69 70 /// 71 /// Helper classes/functions for testing Merger. 72 /// 73 74 /// Simple recursive data structure used to match expressions in Mergers. 75 struct Pattern { 76 Kind kind; 77 78 /// Expressions representing tensors simply have a tensor number. 79 unsigned tensorNum; 80 81 /// Tensor operations point to their children. 82 std::shared_ptr<Pattern> e0; 83 std::shared_ptr<Pattern> e1; 84 85 /// Constructors. 86 /// Rather than using these, please use the readable helper constructor 87 /// functions below to make tests more readable. 88 Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {} 89 Pattern(Kind kind, const std::shared_ptr<Pattern> &e0, 90 const std::shared_ptr<Pattern> &e1) 91 : kind(kind), e0(e0), e1(e1) { 92 assert(kind >= Kind::kMulF); 93 assert(e0 && e1); 94 } 95 }; 96 97 /// 98 /// Readable Pattern builder functions. 99 /// These should be preferred over the actual constructors. 100 /// 101 102 static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) { 103 return std::make_shared<Pattern>(tensorNum); 104 } 105 106 #define IMPL_BINOP_PATTERN(OP, KIND) \ 107 LLVM_ATTRIBUTE_UNUSED static std::shared_ptr<Pattern> OP##Pattern( \ 108 const std::shared_ptr<Pattern> &e0, \ 109 const std::shared_ptr<Pattern> &e1) { \ 110 return std::make_shared<Pattern>(KIND, e0, e1); \ 111 } 112 113 FOREVERY_BINOP(IMPL_BINOP_PATTERN) 114 115 #undef IMPL_BINOP_PATTERN 116 117 class MergerTestBase : public ::testing::Test { 118 protected: 119 MergerTestBase(unsigned numTensors, unsigned numLoops) 120 : numTensors(numTensors), numLoops(numLoops), 121 merger(numTensors, numLoops) {} 122 123 /// 124 /// Expression construction helpers. 125 /// 126 127 unsigned tensor(unsigned tensor) { 128 return merger.addExp(Kind::kTensor, tensor); 129 } 130 131 #define IMPL_BINOP_EXPR(OP, KIND) \ 132 LLVM_ATTRIBUTE_UNUSED unsigned OP##Expr(unsigned e0, unsigned e1) { \ 133 return merger.addExp(KIND, e0, e1); \ 134 } 135 136 FOREVERY_BINOP(IMPL_BINOP_EXPR) 137 138 #undef IMPL_BINOP_EXPR 139 140 /// 141 /// Comparison helpers. 142 /// 143 144 /// For readability of tests. 145 unsigned lat(unsigned lat) { return lat; } 146 147 /// Returns true if a lattice point with an expression matching the given 148 /// pattern and bits matching the given bits is present in lattice points 149 /// [p, p+n) of lattice set s. This is useful for testing partial ordering 150 /// constraints between lattice points. We generally know how contiguous 151 /// groups of lattice points should be ordered with respect to other groups, 152 /// but there is no required ordering within groups. 153 /// If simple is true, then compare the lat.simple field instead to test the 154 /// result after optimization 155 bool latPointWithinRange(unsigned s, unsigned p, unsigned n, 156 const std::shared_ptr<Pattern> &pattern, 157 const BitVector &bits, bool simple) { 158 for (unsigned i = p; i < p + n; ++i) { 159 if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) && 160 compareBits(s, i, bits, simple)) 161 return true; 162 } 163 return false; 164 } 165 166 /// Wrapper over latPointWithinRange for readability of tests. 167 void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n, 168 const std::shared_ptr<Pattern> &pattern, 169 const BitVector &bits, bool simple = false) { 170 EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits, simple)); 171 } 172 173 /// Wrapper over expectLatPointWithinRange for a single lat point. 174 void expectLatPoint(unsigned s, unsigned p, 175 const std::shared_ptr<Pattern> &pattern, 176 const BitVector &bits, bool simple = false) { 177 EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits, simple)); 178 } 179 180 /// Converts a vector of (loop, tensor) pairs to a bitvector with the 181 /// corresponding bits set. 182 BitVector 183 loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) { 184 BitVector testBits = BitVector(numTensors + 1, false); 185 for (auto l : loops) { 186 auto loop = std::get<0>(l); 187 auto tensor = std::get<1>(l); 188 testBits.set(numTensors * loop + tensor); 189 } 190 return testBits; 191 } 192 193 /// Returns true if the bits of lattice point p in set s match the given bits. 194 /// If simple is true, then compare the lat.simple field instead to test the 195 /// result after optimization 196 bool compareBits(unsigned s, unsigned p, const BitVector &bits, bool simple) { 197 if (simple) 198 return merger.lat(merger.set(s)[p]).simple == bits; 199 return merger.lat(merger.set(s)[p]).bits == bits; 200 } 201 202 /// Check that there are n lattice points in set s. 203 void expectNumLatPoints(unsigned s, unsigned n) { 204 EXPECT_THAT(merger.set(s).size(), n); 205 } 206 207 /// Compares expressions for equality. Equality is defined recursively as: 208 /// - Operations are equal if they have the same kind and children. 209 /// - Leaf tensors are equal if they refer to the same tensor. 210 bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) { 211 auto tensorExp = merger.exp(e); 212 if (tensorExp.kind != pattern->kind) 213 return false; 214 switch (tensorExp.kind) { 215 // Leaf. 216 case kTensor: 217 return tensorExp.tensor == pattern->tensorNum; 218 case kInvariant: 219 case kIndex: 220 llvm_unreachable("invariant not handled yet"); 221 // Unary operations. 222 case kAbsF: 223 case kAbsC: 224 case kCeilF: 225 case kFloorF: 226 case kSqrtF: 227 case kSqrtC: 228 case kExpm1F: 229 case kExpm1C: 230 case kLog1pF: 231 case kLog1pC: 232 case kSinF: 233 case kSinC: 234 case kTanhF: 235 case kTanhC: 236 case kNegF: 237 case kNegC: 238 case kNegI: 239 case kTruncF: 240 case kExtF: 241 case kCastFS: 242 case kCastFU: 243 case kCastSF: 244 case kCastUF: 245 case kCastS: 246 case kCastU: 247 case kCastIdx: 248 case kTruncI: 249 case kCIm: 250 case kCRe: 251 case kBitCast: 252 case kBinaryBranch: 253 case kUnary: 254 case kShlI: 255 case kBinary: 256 return compareExpression(tensorExp.children.e0, pattern->e0); 257 // Binary operations. 258 case kMulF: 259 case kMulC: 260 case kMulI: 261 case kDivF: 262 case kDivC: 263 case kDivS: 264 case kDivU: 265 case kAddF: 266 case kAddC: 267 case kAddI: 268 case kSubF: 269 case kSubC: 270 case kSubI: 271 case kAndI: 272 case kOrI: 273 case kXorI: 274 case kShrS: 275 case kShrU: 276 return compareExpression(tensorExp.children.e0, pattern->e0) && 277 compareExpression(tensorExp.children.e1, pattern->e1); 278 } 279 llvm_unreachable("unexpected kind"); 280 } 281 282 unsigned numTensors; 283 unsigned numLoops; 284 Merger merger; 285 }; 286 287 /// 288 /// Tests with all sparse inputs. 289 /// 290 291 class MergerTest3T1L : public MergerTestBase { 292 protected: 293 // Our three tensors (two inputs, one output). 294 const unsigned t0 = 0, t1 = 1, t2 = 2; 295 296 // Our single loop. 297 const unsigned l0 = 0; 298 299 MergerTest3T1L() : MergerTestBase(3, 1) { 300 // Tensor 0: sparse input vector. 301 merger.addExp(Kind::kTensor, t0, -1u); 302 merger.setDim(t0, l0, Dim::kSparse); 303 304 // Tensor 1: sparse input vector. 305 merger.addExp(Kind::kTensor, t1, -1u); 306 merger.setDim(t1, l0, Dim::kSparse); 307 308 // Tensor 2: dense output vector. 309 merger.addExp(Kind::kTensor, t2, -1u); 310 merger.setDim(t2, l0, Dim::kDense); 311 } 312 }; 313 314 class MergerTest4T1L : public MergerTestBase { 315 protected: 316 // Our four tensors (three inputs, one output). 317 const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3; 318 319 // Our single loop. 320 const unsigned l0 = 0; 321 322 MergerTest4T1L() : MergerTestBase(4, 1) { 323 // Tensor 0: sparse input vector. 324 merger.addExp(Kind::kTensor, t0, -1u); 325 merger.setDim(t0, l0, Dim::kSparse); 326 327 // Tensor 1: sparse input vector. 328 merger.addExp(Kind::kTensor, t1, -1u); 329 merger.setDim(t1, l0, Dim::kSparse); 330 331 // Tensor 2: sparse input vector 332 merger.addExp(Kind::kTensor, t2, -1u); 333 merger.setDim(t2, l0, Dim::kSparse); 334 335 // Tensor 3: dense output vector 336 merger.addExp(Kind::kTensor, t3, -1u); 337 merger.setDim(t3, l0, Dim::kDense); 338 } 339 }; 340 341 /// 342 /// Tests with both sparse and dense input. 343 /// 344 345 class MergerTest3T1LD : public MergerTestBase { 346 protected: 347 // Our three tensors (two inputs, one output). 348 const unsigned t0 = 0, t1 = 1, t2 = 2; 349 350 // Our single loop. 351 const unsigned l0 = 0; 352 353 MergerTest3T1LD() : MergerTestBase(3, 1) { 354 // Tensor 0: sparse input vector. 355 merger.addExp(Kind::kTensor, t0, -1u); 356 merger.setDim(t0, l0, Dim::kSparse); 357 358 // Tensor 1: dense input vector. 359 merger.addExp(Kind::kTensor, t1, -1u); 360 merger.setDim(t1, l0, Dim::kDense); 361 362 // Tensor 2: dense output vector. 363 merger.addExp(Kind::kTensor, t2, -1u); 364 merger.setDim(t2, l0, Dim::kDense); 365 } 366 }; 367 368 } // namespace 369 370 /// Vector addition (disjunction) of 2 vectors. i.e.; 371 /// a(i) = b(i) + c(i) 372 /// which should form the 3 lattice points 373 /// { 374 /// lat( i_00 i_01 / (tensor_0 + tensor_1) ) 375 /// lat( i_00 / tensor_0 ) 376 /// lat( i_01 / tensor_1 ) 377 /// } 378 /// and after optimization, the lattice points do not change (as there is no 379 /// duplicated point and all input vectors are sparse vector). 380 /// { 381 /// lat( i_00 i_01 / (tensor_0 + tensor_1) ) 382 /// lat( i_00 / tensor_0 ) 383 /// lat( i_01 / tensor_1 ) 384 /// } 385 #define IMPL_MERGER_TEST_DISJ(OP) \ 386 TEST_F(MergerTest3T1L, vector_##OP) { \ 387 auto e = OP##Expr(tensor(t0), tensor(t1)); \ 388 auto p0 = tensorPattern(t0); \ 389 auto p1 = tensorPattern(t1); \ 390 auto s = merger.buildLattices(e, l0); \ 391 \ 392 expectNumLatPoints(s, 3); \ 393 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ 394 loopsToBits({{l0, t0}, {l0, t1}})); \ 395 expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \ 396 expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \ 397 \ 398 s = merger.optimizeSet(s); \ 399 expectNumLatPoints(s, 3); \ 400 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ 401 loopsToBits({{l0, t0}, {l0, t1}}), true); \ 402 expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}), \ 403 true); \ 404 expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}), \ 405 true); \ 406 } 407 408 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ) 409 410 #undef IMPL_MERGER_TEST_DISJ 411 412 /// Vector multiplication (conjunction) of 2 vectors, i.e.; 413 /// a(i) = b(i) * c(i) 414 /// which should form the single lattice point 415 /// { 416 /// lat( i_00 i_01 / (tensor_0 * tensor_1) ) 417 /// } 418 #define IMPL_MERGER_TEST_CONJ(OP) \ 419 TEST_F(MergerTest3T1L, vector_##OP) { \ 420 auto e = OP##Expr(t0, t1); \ 421 auto p0 = tensorPattern(t0); \ 422 auto p1 = tensorPattern(t1); \ 423 auto s = merger.buildLattices(e, l0); \ 424 \ 425 expectNumLatPoints(s, 1); \ 426 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ 427 loopsToBits({{l0, t0}, {l0, t1}})); \ 428 \ 429 s = merger.optimizeSet(s); \ 430 expectNumLatPoints(s, 1); \ 431 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ 432 loopsToBits({{l0, t0}, {l0, t1}}), true); \ 433 } 434 435 FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ) 436 437 #undef IMPL_MERGER_TEST_CONJ 438 439 /// Vector multiplication (conjunction) then addition (disjunction), i.e.; 440 /// a(i) = b(i) * c(i) + d(i); 441 /// which should form 442 /// { 443 /// lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 ) 444 /// lat( i_00 i_01 / tensor_0 * tensor_1 445 /// lat( i_02 / tensor_2 ) 446 /// } 447 #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \ 448 TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \ 449 auto em = CONJ##Expr(t0, t1); \ 450 auto e = DISJ##Expr(em, t2); \ 451 auto p0 = tensorPattern(t0); \ 452 auto p1 = tensorPattern(t1); \ 453 auto p2 = tensorPattern(t2); \ 454 auto s = merger.buildLattices(e, l0); \ 455 \ 456 expectNumLatPoints(s, 3); \ 457 expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \ 458 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ 459 expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \ 460 loopsToBits({{l0, t0}, {l0, t1}})); \ 461 expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \ 462 \ 463 s = merger.optimizeSet(s); \ 464 expectNumLatPoints(s, 3); \ 465 expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \ 466 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ 467 expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \ 468 loopsToBits({{l0, t0}, {l0, t1}})); \ 469 expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \ 470 } 471 472 FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ) 473 474 #undef IMPL_MERGER_TEST_CONJ_DISJ 475 476 /// Vector addition (disjunction) then addition (disjunction), i.e.; 477 /// a(i) = b(i) + c(i) + d(i) 478 /// which should form 479 /// { 480 /// lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 ) 481 /// lat( i_02 i_01 / tensor_2 + tensor_1 ) 482 /// lat( i_02 i_00 / tensor_2 + tensor_0 ) 483 /// lat( i_01 i_00 / tensor_1 + tensor_0 ) 484 /// lat( i_02 / tensor_2 ) 485 /// lat( i_01 / tensor_1 ) 486 /// lat( i_00 / tensor_0 ) 487 /// } 488 #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \ 489 TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \ 490 auto em = DISJ1##Expr(t0, t1); \ 491 auto e = DISJ2##Expr(em, t2); \ 492 auto p0 = tensorPattern(t0); \ 493 auto p1 = tensorPattern(t1); \ 494 auto p2 = tensorPattern(t2); \ 495 auto s = merger.buildLattices(e, l0); \ 496 \ 497 expectNumLatPoints(s, 7); \ 498 expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \ 499 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ 500 expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \ 501 loopsToBits({{l0, t1}, {l0, t2}})); \ 502 expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \ 503 loopsToBits({{l0, t0}, {l0, t2}})); \ 504 expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \ 505 loopsToBits({{l0, t0}, {l0, t1}})); \ 506 expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \ 507 expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \ 508 expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \ 509 \ 510 s = merger.optimizeSet(s); \ 511 expectNumLatPoints(s, 7); \ 512 expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \ 513 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ 514 expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \ 515 loopsToBits({{l0, t1}, {l0, t2}})); \ 516 expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \ 517 loopsToBits({{l0, t0}, {l0, t2}})); \ 518 expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \ 519 loopsToBits({{l0, t0}, {l0, t1}})); \ 520 expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \ 521 expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \ 522 expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \ 523 } 524 525 FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ) 526 527 #undef IMPL_MERGER_TEST_DISJ_DISJ 528 529 /// Vector multiplication (conjunction) then multiplication (conjunction), i.e.; 530 /// a(i) = b(i) * c(i) * d(i); 531 /// which should form 532 /// { 533 /// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 ) 534 /// } 535 #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \ 536 TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \ 537 auto em = CONJ1##Expr(t0, t1); \ 538 auto e = CONJ2##Expr(em, t2); \ 539 auto p0 = tensorPattern(t0); \ 540 auto p1 = tensorPattern(t1); \ 541 auto p2 = tensorPattern(t2); \ 542 auto s = merger.buildLattices(e, l0); \ 543 expectNumLatPoints(s, 1); \ 544 expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ 545 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ 546 s = merger.optimizeSet(s); \ 547 expectNumLatPoints(s, 1); \ 548 expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ 549 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \ 550 } 551 552 FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ) 553 554 #undef IMPL_MERGER_TEST_CONJ_CONJ 555 556 /// Vector addition (disjunction) of 2 vectors, i.e.; 557 /// a(i) = b(i) + c(i) 558 /// which should form the 3 lattice points 559 /// { 560 /// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) 561 /// lat( i_00 / sparse_tensor_0 ) 562 /// lat( i_01 / dense_tensor_1 ) 563 /// } 564 /// which should be optimized to 565 /// { 566 /// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton) 567 /// lat( i_01 / dense_tensor_0 ) (no sparse dimension) 568 /// } 569 /// 570 /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff 571 /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ). 572 #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP) \ 573 TEST_F(MergerTest3T1LD, vector_opted_##OP) { \ 574 auto e = OP##Expr(tensor(t0), tensor(t1)); \ 575 auto p0 = tensorPattern(t0); \ 576 auto p1 = tensorPattern(t1); \ 577 auto s = merger.buildLattices(e, l0); \ 578 \ 579 expectNumLatPoints(s, 3); \ 580 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ 581 loopsToBits({{l0, t0}, {l0, t1}})); \ 582 expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \ 583 expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \ 584 \ 585 s = merger.optimizeSet(s); \ 586 expectNumLatPoints(s, 2); \ 587 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ 588 loopsToBits({{l0, t0}, {l0, t1}}), true); \ 589 expectLatPoint(s, lat(1), p1, loopsToBits({{l0, t1}}), true); \ 590 } 591 592 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ) 593 594 #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ 595 596 /// Vector multiplication (conjunction) of 2 vectors, i.e.: 597 /// a(i) = b(i) * c(i) 598 /// which should form the single lattice point 599 /// { 600 /// lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) ) 601 /// } 602 /// it should be optimized to 603 /// { 604 /// lat( i_00 / (sparse_tensor_0 * dense_tensor_1) ) 605 /// } 606 /// since i_01 is a dense dimension. 607 #define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP) \ 608 TEST_F(MergerTest3T1LD, vector_opted_##OP) { \ 609 auto e = OP##Expr(t0, t1); \ 610 auto p0 = tensorPattern(t0); \ 611 auto p1 = tensorPattern(t1); \ 612 auto s = merger.buildLattices(e, l0); \ 613 \ 614 expectNumLatPoints(s, 1); \ 615 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ 616 loopsToBits({{l0, t0}, {l0, t1}})); \ 617 \ 618 s = merger.optimizeSet(s); \ 619 expectNumLatPoints(s, 1); \ 620 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t0}}), \ 621 true); \ 622 } 623 624 FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ) 625 626 #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ 627 628 // TODO: mult-dim tests 629