#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "llvm/Support/Compiler.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <memory>

using namespace mlir;
using namespace mlir::sparse_tensor;

// Silence 'warning C4002: too many arguments for function-like macro
// invocation', as MSVC handles ##__VA_ARGS__ differently than gcc/clang.
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 4002)
#endif

namespace {

///
/// Defines macros to iterate binary operations and combinations thereof.
///

// Applies DO(<op name>, <expression kind>) to every binary operation kind
// exercised by the tests below.
#define FOREVERY_BINOP(DO)                                                     \
  DO(mulf, Kind::kMulF)                                                        \
  DO(mulc, Kind::kMulC)                                                        \
  DO(muli, Kind::kMulI)                                                        \
  DO(addf, Kind::kAddF)                                                        \
  DO(addc, Kind::kAddC)                                                        \
  DO(addi, Kind::kAddI)                                                        \
  DO(subf, Kind::kSubF)                                                        \
  DO(subc, Kind::kSubC)                                                        \
  DO(subi, Kind::kSubI)                                                        \
  DO(andi, Kind::kAndI)                                                        \
  DO(xori, Kind::kXorI)                                                        \
  DO(ori, Kind::kOrI)

// TODO: Disjunctive binary operations that need special handling are not
// included, e.g., division is not tested (for now) as it needs a constant
// non-zero dividend.
// ##__VA_ARGS__ handles cases when __VA_ARGS__ is empty.
#define FOREVERY_COMMON_DISJ_BINOP(TEST, ...)                                  \
  TEST(addf, ##__VA_ARGS__)                                                    \
  TEST(addc, ##__VA_ARGS__)                                                    \
  TEST(addi, ##__VA_ARGS__)                                                    \
  TEST(xori, ##__VA_ARGS__)                                                    \
  TEST(ori, ##__VA_ARGS__)

// TODO: Conjunctive binary operations that need special handling are not
// included, e.g., subtraction yields a different pattern as it is mapped to
// a negate operation.
#define FOREVERY_COMMON_CONJ_BINOP(TEST, ...)                                  \
  TEST(mulf, ##__VA_ARGS__)                                                    \
  TEST(mulc, ##__VA_ARGS__)                                                    \
  TEST(muli, ##__VA_ARGS__)                                                    \
  TEST(andi, ##__VA_ARGS__)

// Applies TEST(<conjunction op>, <disjunction op>) to every pair of common
// conjunctive and disjunctive operations.
#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST)                          \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addf)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addc)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addi)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, xori)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, ori)

// Applies TEST(<conjunction op>, <conjunction op>) to every pair of common
// conjunctive operations.
#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST)                          \
  FOREVERY_COMMON_CONJ_BINOP(TEST, mulf)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, mulc)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, muli)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, andi)

// Applies TEST(<disjunction op>, <disjunction op>) to every pair of common
// disjunctive operations.
#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(TEST)                          \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addf)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addc)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addi)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, ori)                                        \
  FOREVERY_COMMON_DISJ_BINOP(TEST, xori)

///
/// Helper classes/functions for testing Merger.
///

/// Simple recursive data structure used to match expressions in Mergers.
/// A Pattern is either a leaf (a tensor number) or an operation kind with
/// two child patterns.
struct Pattern {
  Kind kind;

  /// Expressions representing tensors simply have a tensor number.
  unsigned tensorNum;

  /// Tensor operations point to their children.
  std::shared_ptr<Pattern> e0;
  std::shared_ptr<Pattern> e1;

  /// Constructors.
  /// Rather than using these, please use the readable helper constructor
  /// functions below to make tests more readable.
  Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
  Pattern(Kind kind, const std::shared_ptr<Pattern> &e0,
          const std::shared_ptr<Pattern> &e1)
      : kind(kind), e0(e0), e1(e1) {
    // Operation kinds start at kMulF; smaller values are leaves and must use
    // the leaf constructor above.
    assert(kind >= Kind::kMulF);
    assert(e0 && e1);
  }
};

///
/// Readable Pattern builder functions.
/// These should be preferred over the actual constructors.
///

/// Builds a leaf Pattern matching the given tensor number.
static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
  return std::make_shared<Pattern>(tensorNum);
}

// Declares an <op>Pattern(e0, e1) builder function for every binary
// operation kind.
#define IMPL_BINOP_PATTERN(OP, KIND)                                           \
  LLVM_ATTRIBUTE_UNUSED static std::shared_ptr<Pattern> OP##Pattern(           \
      const std::shared_ptr<Pattern> &e0,                                      \
      const std::shared_ptr<Pattern> &e1) {                                    \
    return std::make_shared<Pattern>(KIND, e0, e1);                            \
  }

FOREVERY_BINOP(IMPL_BINOP_PATTERN)

#undef IMPL_BINOP_PATTERN

/// Base fixture that owns a Merger and provides expression-building and
/// lattice-comparison helpers shared by the concrete fixtures below.
class MergerTestBase : public ::testing::Test {
protected:
  MergerTestBase(unsigned numTensors, unsigned numLoops)
      : numTensors(numTensors), numLoops(numLoops),
        merger(numTensors, numLoops) {}

  ///
  /// Expression construction helpers.
  ///

  /// Registers a tensor-leaf expression for the given tensor id and returns
  /// the new expression id.
  unsigned tensor(unsigned tensor) {
    return merger.addExp(Kind::kTensor, tensor);
  }

// Declares an <op>Expr(e0, e1) helper that registers a binary expression
// with the merger and returns its expression id.
#define IMPL_BINOP_EXPR(OP, KIND)                                              \
  LLVM_ATTRIBUTE_UNUSED unsigned OP##Expr(unsigned e0, unsigned e1) {          \
    return merger.addExp(KIND, e0, e1);                                        \
  }

  FOREVERY_BINOP(IMPL_BINOP_EXPR)

#undef IMPL_BINOP_EXPR

  ///
  /// Comparison helpers.
  ///

  /// For readability of tests.
  unsigned lat(unsigned lat) { return lat; }

  /// Returns true if a lattice point with an expression matching the given
  /// pattern and bits matching the given bits is present in lattice points
  /// [p, p+n) of lattice set s. This is useful for testing partial ordering
  /// constraints between lattice points. We generally know how contiguous
  /// groups of lattice points should be ordered with respect to other groups,
  /// but there is no required ordering within groups.
162 /// If simple is true, then compare the lat.simple field instead to test the 163 /// result after optimization 164 bool latPointWithinRange(unsigned s, unsigned p, unsigned n, 165 const std::shared_ptr<Pattern> &pattern, 166 const BitVector &bits, bool simple) { 167 for (unsigned i = p; i < p + n; ++i) { 168 if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) && 169 compareBits(s, i, bits, simple)) 170 return true; 171 } 172 return false; 173 } 174 175 /// Wrapper over latPointWithinRange for readability of tests. 176 void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n, 177 const std::shared_ptr<Pattern> &pattern, 178 const BitVector &bits, bool simple = false) { 179 EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits, simple)); 180 } 181 182 /// Wrapper over expectLatPointWithinRange for a single lat point. 183 void expectLatPoint(unsigned s, unsigned p, 184 const std::shared_ptr<Pattern> &pattern, 185 const BitVector &bits, bool simple = false) { 186 EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits, simple)); 187 } 188 189 /// Converts a vector of (loop, tensor) pairs to a bitvector with the 190 /// corresponding bits set. 191 BitVector 192 loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) { 193 BitVector testBits = BitVector(numTensors + 1, false); 194 for (auto l : loops) { 195 auto loop = std::get<0>(l); 196 auto tensor = std::get<1>(l); 197 testBits.set(numTensors * loop + tensor); 198 } 199 return testBits; 200 } 201 202 /// Returns true if the bits of lattice point p in set s match the given bits. 203 /// If simple is true, then compare the lat.simple field instead to test the 204 /// result after optimization 205 bool compareBits(unsigned s, unsigned p, const BitVector &bits, bool simple) { 206 if (simple) 207 return merger.lat(merger.set(s)[p]).simple == bits; 208 return merger.lat(merger.set(s)[p]).bits == bits; 209 } 210 211 /// Check that there are n lattice points in set s. 
  /// Check that there are n lattice points in set s.
  void expectNumLatPoints(unsigned s, unsigned n) {
    EXPECT_THAT(merger.set(s).size(), n);
  }

  /// Compares expressions for equality. Equality is defined recursively as:
  /// - Operations are equal if they have the same kind and children.
  /// - Leaf tensors are equal if they refer to the same tensor.
  bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
    auto tensorExp = merger.exp(e);
    // A kind mismatch can never match, regardless of children.
    if (tensorExp.kind != pattern->kind)
      return false;
    switch (tensorExp.kind) {
    // Leaf.
    case kTensor:
      return tensorExp.tensor == pattern->tensorNum;
    case kInvariant:
    case kIndex:
      llvm_unreachable("invariant not handled yet");
    // Unary operations.
    case kAbsF:
    case kAbsC:
    case kCeilF:
    case kFloorF:
    case kSqrtF:
    case kSqrtC:
    case kExpm1F:
    case kExpm1C:
    case kLog1pF:
    case kLog1pC:
    case kSinF:
    case kSinC:
    case kTanhF:
    case kTanhC:
    case kNegF:
    case kNegC:
    case kNegI:
    case kTruncF:
    case kExtF:
    case kCastFS:
    case kCastFU:
    case kCastSF:
    case kCastUF:
    case kCastS:
    case kCastU:
    case kCastIdx:
    case kTruncI:
    case kCIm:
    case kCRe:
    case kBitCast:
    case kBinaryBranch:
    case kUnary:
    case kShlI:
    case kBinary:
      // Only the first child is compared for these kinds.
      // NOTE(review): kShlI is grouped here while kShrS/kShrU are compared as
      // binary below — presumably because a shift's right operand is an
      // invariant; confirm against Merger's expression construction.
      return compareExpression(tensorExp.children.e0, pattern->e0);
    // Binary operations.
    case kMulF:
    case kMulC:
    case kMulI:
    case kDivF:
    case kDivC:
    case kDivS:
    case kDivU:
    case kAddF:
    case kAddC:
    case kAddI:
    case kSubF:
    case kSubC:
    case kSubI:
    case kAndI:
    case kOrI:
    case kXorI:
    case kShrS:
    case kShrU:
      return compareExpression(tensorExp.children.e0, pattern->e0) &&
             compareExpression(tensorExp.children.e1, pattern->e1);
    }
    llvm_unreachable("unexpected kind");
  }

  unsigned numTensors;
  unsigned numLoops;
  Merger merger;
};

///
/// Tests with all sparse inputs.
///

/// Fixture: three tensors (two sparse inputs, one dense output), one loop.
class MergerTest3T1L : public MergerTestBase {
protected:
  // Our three tensors (two inputs, one output).
  const unsigned t0 = 0, t1 = 1, t2 = 2;

  // Our single loop.
  const unsigned l0 = 0;

  MergerTest3T1L() : MergerTestBase(3, 1) {
    // NOTE(review): the -1u third argument presumably marks "no operand" for
    // leaf expressions — confirm against Merger::addExp.
    // Tensor 0: sparse input vector.
    merger.addExp(Kind::kTensor, t0, -1u);
    merger.setDim(t0, l0, Dim::kSparse);

    // Tensor 1: sparse input vector.
    merger.addExp(Kind::kTensor, t1, -1u);
    merger.setDim(t1, l0, Dim::kSparse);

    // Tensor 2: dense output vector.
    merger.addExp(Kind::kTensor, t2, -1u);
    merger.setDim(t2, l0, Dim::kDense);
  }
};

/// Fixture: four tensors (three sparse inputs, one dense output), one loop.
class MergerTest4T1L : public MergerTestBase {
protected:
  // Our four tensors (three inputs, one output).
  const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;

  // Our single loop.
  const unsigned l0 = 0;

  MergerTest4T1L() : MergerTestBase(4, 1) {
    // Tensor 0: sparse input vector.
    merger.addExp(Kind::kTensor, t0, -1u);
    merger.setDim(t0, l0, Dim::kSparse);

    // Tensor 1: sparse input vector.
    merger.addExp(Kind::kTensor, t1, -1u);
    merger.setDim(t1, l0, Dim::kSparse);

    // Tensor 2: sparse input vector.
    merger.addExp(Kind::kTensor, t2, -1u);
    merger.setDim(t2, l0, Dim::kSparse);

    // Tensor 3: dense output vector.
    merger.addExp(Kind::kTensor, t3, -1u);
    merger.setDim(t3, l0, Dim::kDense);
  }
};

///
/// Tests with both sparse and dense input.
///

/// Fixture: three tensors (one sparse input, one dense input, one dense
/// output), one loop.
class MergerTest3T1LD : public MergerTestBase {
protected:
  // Our three tensors (two inputs, one output).
  const unsigned t0 = 0, t1 = 1, t2 = 2;

  // Our single loop.
  const unsigned l0 = 0;

  MergerTest3T1LD() : MergerTestBase(3, 1) {
    // Tensor 0: sparse input vector.
    merger.addExp(Kind::kTensor, t0, -1u);
    merger.setDim(t0, l0, Dim::kSparse);

    // Tensor 1: dense input vector.
    merger.addExp(Kind::kTensor, t1, -1u);
    merger.setDim(t1, l0, Dim::kDense);

    // Tensor 2: dense output vector.
    merger.addExp(Kind::kTensor, t2, -1u);
    merger.setDim(t2, l0, Dim::kDense);
  }
};

} // namespace

/// Vector addition (disjunction) of 2 vectors. i.e.;
/// a(i) = b(i) + c(i)
/// which should form the 3 lattice points
/// {
///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
///   lat( i_00 / tensor_0 )
///   lat( i_01 / tensor_1 )
/// }
/// and after optimization, the lattice points do not change (as there is no
/// duplicated point and all input vectors are sparse vectors).
389 /// { 390 /// lat( i_00 i_01 / (tensor_0 + tensor_1) ) 391 /// lat( i_00 / tensor_0 ) 392 /// lat( i_01 / tensor_1 ) 393 /// } 394 #define IMPL_MERGER_TEST_DISJ(OP) \ 395 TEST_F(MergerTest3T1L, vector_##OP) { \ 396 auto e = OP##Expr(tensor(t0), tensor(t1)); \ 397 auto p0 = tensorPattern(t0); \ 398 auto p1 = tensorPattern(t1); \ 399 auto s = merger.buildLattices(e, l0); \ 400 \ 401 expectNumLatPoints(s, 3); \ 402 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ 403 loopsToBits({{l0, t0}, {l0, t1}})); \ 404 expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \ 405 expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \ 406 \ 407 s = merger.optimizeSet(s); \ 408 expectNumLatPoints(s, 3); \ 409 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ 410 loopsToBits({{l0, t0}, {l0, t1}}), true); \ 411 expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}), \ 412 true); \ 413 expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}), \ 414 true); \ 415 } 416 417 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ) 418 419 #undef IMPL_MERGER_TEST_DISJ 420 421 /// Vector multiplication (conjunction) of 2 vectors, i.e.; 422 /// a(i) = b(i) * c(i) 423 /// which should form the single lattice point 424 /// { 425 /// lat( i_00 i_01 / (tensor_0 * tensor_1) ) 426 /// } 427 #define IMPL_MERGER_TEST_CONJ(OP) \ 428 TEST_F(MergerTest3T1L, vector_##OP) { \ 429 auto e = OP##Expr(t0, t1); \ 430 auto p0 = tensorPattern(t0); \ 431 auto p1 = tensorPattern(t1); \ 432 auto s = merger.buildLattices(e, l0); \ 433 \ 434 expectNumLatPoints(s, 1); \ 435 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ 436 loopsToBits({{l0, t0}, {l0, t1}})); \ 437 \ 438 s = merger.optimizeSet(s); \ 439 expectNumLatPoints(s, 1); \ 440 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ 441 loopsToBits({{l0, t0}, {l0, t1}}), true); \ 442 } 443 444 FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ) 445 446 #undef IMPL_MERGER_TEST_CONJ 447 448 /// Vector multiplication 
(conjunction) then addition (disjunction), i.e.; 449 /// a(i) = b(i) * c(i) + d(i); 450 /// which should form 451 /// { 452 /// lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 ) 453 /// lat( i_00 i_01 / tensor_0 * tensor_1 454 /// lat( i_02 / tensor_2 ) 455 /// } 456 #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \ 457 TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \ 458 auto em = CONJ##Expr(t0, t1); \ 459 auto e = DISJ##Expr(em, t2); \ 460 auto p0 = tensorPattern(t0); \ 461 auto p1 = tensorPattern(t1); \ 462 auto p2 = tensorPattern(t2); \ 463 auto s = merger.buildLattices(e, l0); \ 464 \ 465 expectNumLatPoints(s, 3); \ 466 expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \ 467 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ 468 expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \ 469 loopsToBits({{l0, t0}, {l0, t1}})); \ 470 expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \ 471 \ 472 s = merger.optimizeSet(s); \ 473 expectNumLatPoints(s, 3); \ 474 expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \ 475 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ 476 expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \ 477 loopsToBits({{l0, t0}, {l0, t1}})); \ 478 expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \ 479 } 480 481 FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ) 482 483 #undef IMPL_MERGER_TEST_CONJ_DISJ 484 485 /// Vector addition (disjunction) then addition (disjunction), i.e.; 486 /// a(i) = b(i) + c(i) + d(i) 487 /// which should form 488 /// { 489 /// lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 ) 490 /// lat( i_02 i_01 / tensor_2 + tensor_1 ) 491 /// lat( i_02 i_00 / tensor_2 + tensor_0 ) 492 /// lat( i_01 i_00 / tensor_1 + tensor_0 ) 493 /// lat( i_02 / tensor_2 ) 494 /// lat( i_01 / tensor_1 ) 495 /// lat( i_00 / tensor_0 ) 496 /// } 497 #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \ 498 TEST_F(MergerTest4T1L, 
Vector_##DISJ1##_##DISJ2) { \ 499 auto em = DISJ1##Expr(t0, t1); \ 500 auto e = DISJ2##Expr(em, t2); \ 501 auto p0 = tensorPattern(t0); \ 502 auto p1 = tensorPattern(t1); \ 503 auto p2 = tensorPattern(t2); \ 504 auto s = merger.buildLattices(e, l0); \ 505 \ 506 expectNumLatPoints(s, 7); \ 507 expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \ 508 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ 509 expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \ 510 loopsToBits({{l0, t1}, {l0, t2}})); \ 511 expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \ 512 loopsToBits({{l0, t0}, {l0, t2}})); \ 513 expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \ 514 loopsToBits({{l0, t0}, {l0, t1}})); \ 515 expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \ 516 expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \ 517 expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \ 518 \ 519 s = merger.optimizeSet(s); \ 520 expectNumLatPoints(s, 7); \ 521 expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \ 522 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ 523 expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \ 524 loopsToBits({{l0, t1}, {l0, t2}})); \ 525 expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \ 526 loopsToBits({{l0, t0}, {l0, t2}})); \ 527 expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \ 528 loopsToBits({{l0, t0}, {l0, t1}})); \ 529 expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \ 530 expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \ 531 expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \ 532 } 533 534 FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ) 535 536 #undef IMPL_MERGER_TEST_DISJ_DISJ 537 538 /// Vector multiplication (conjunction) then multiplication (conjunction), i.e.; 539 /// a(i) = b(i) * c(i) * d(i); 540 
/// which should form 541 /// { 542 /// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 ) 543 /// } 544 #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \ 545 TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \ 546 auto em = CONJ1##Expr(t0, t1); \ 547 auto e = CONJ2##Expr(em, t2); \ 548 auto p0 = tensorPattern(t0); \ 549 auto p1 = tensorPattern(t1); \ 550 auto p2 = tensorPattern(t2); \ 551 auto s = merger.buildLattices(e, l0); \ 552 expectNumLatPoints(s, 1); \ 553 expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ 554 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \ 555 s = merger.optimizeSet(s); \ 556 expectNumLatPoints(s, 1); \ 557 expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \ 558 loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \ 559 } 560 561 FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ) 562 563 #undef IMPL_MERGER_TEST_CONJ_CONJ 564 565 /// Vector addition (disjunction) of 2 vectors, i.e.; 566 /// a(i) = b(i) + c(i) 567 /// which should form the 3 lattice points 568 /// { 569 /// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) 570 /// lat( i_00 / sparse_tensor_0 ) 571 /// lat( i_01 / dense_tensor_1 ) 572 /// } 573 /// which should be optimized to 574 /// { 575 /// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton) 576 /// lat( i_01 / dense_tensor_0 ) (no sparse dimension) 577 /// } 578 /// 579 /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff 580 /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ). 
581 #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP) \ 582 TEST_F(MergerTest3T1LD, vector_opted_##OP) { \ 583 auto e = OP##Expr(tensor(t0), tensor(t1)); \ 584 auto p0 = tensorPattern(t0); \ 585 auto p1 = tensorPattern(t1); \ 586 auto s = merger.buildLattices(e, l0); \ 587 \ 588 expectNumLatPoints(s, 3); \ 589 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ 590 loopsToBits({{l0, t0}, {l0, t1}})); \ 591 expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \ 592 expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \ 593 \ 594 s = merger.optimizeSet(s); \ 595 expectNumLatPoints(s, 2); \ 596 expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \ 597 loopsToBits({{l0, t0}, {l0, t1}}), true); \ 598 expectLatPoint(s, lat(1), p1, loopsToBits({{l0, t1}}), true); \ 599 } 600 601 FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ) 602 603 #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ 604 605 /// Vector multiplication (conjunction) of 2 vectors, i.e.: 606 /// a(i) = b(i) * c(i) 607 /// which should form the single lattice point 608 /// { 609 /// lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) ) 610 /// } 611 /// it should be optimized to 612 /// { 613 /// lat( i_00 / (sparse_tensor_0 * dense_tensor_1) ) 614 /// } 615 /// since i_01 is a dense dimension. 
#define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP)                                    \
  TEST_F(MergerTest3T1LD, vector_opted_##OP) {                                 \
    auto e = OP##Expr(t0, t1);                                                 \
    auto p0 = tensorPattern(t0);                                               \
    auto p1 = tensorPattern(t1);                                               \
    auto s = merger.buildLattices(e, l0);                                      \
                                                                               \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, lat(0), OP##Pattern(p0, p1),                             \
                   loopsToBits({{l0, t0}, {l0, t1}}));                         \
                                                                               \
    s = merger.optimizeSet(s);                                                 \
    expectNumLatPoints(s, 1);                                                  \
    expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t0}}),    \
                   true);                                                      \
  }

FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)

#undef IMPL_MERGER_TEST_OPTIMIZED_CONJ

// TODO: multi-dim tests

// Restore the MSVC warning state pushed at the top of this file.
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif