//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  // Leaf.
  case kTensor:
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v && !o);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    break;
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kCIm:
  case kCRe:
    assert(x != -1u && y == -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinaryBranch:
    assert(x != -1u && y == -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kAndI:
  case kOrI:
  case kXorI:
  case kShrS:
  case kShrU:
  case kShlI:
    assert(x != -1u && y != -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinary:
    assert(x != -1u && y != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}
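
// Note on the lattice bit encoding: each (tensor, loop-index) pair occupies
// a single bit in a LatPoint's bit vector, namely bit b = i * numTensors + t
// for loop index i and tensor t (see addLat() below). For example, with
// numTensors == 3, tensor 1 at loop 2 is encoded as bit 2 * 3 + 1 == 7.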

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
                        Operation *op) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
                              Operation *op) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = takeConj(kind, s0, s1, op);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubC)
    s1 = mapSet(kNegC, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
                           bool includeLeft, Kind ltrans, Operation *opleft,
                           bool includeRight, Kind rtrans, Operation *opright) {
  unsigned s = takeConj(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
  }
  return s;
}

unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
  assert(kAbsF <= kind && kind <= kUnary);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v, op);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
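      // (That is, some already accepted lattice point differs from this one
      //  only in dense dimensions, which makes the new point redundant.)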
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  // Leaf.
  case kTensor:
    return tensorExps[e].tensor == t;
  case kInvariant:
  case kIndex:
    return false;
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kBinaryBranch:
  case kUnary:
    return false;
  // Binary operations.
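  // (Divisions and shifts only reach this point with a nonzero constant or
  //  invariant right-hand side; buildTensorExp() rejects all other cases,
  //  as the assertions below reiterate.)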
  case kDivF: // note: x / c only
  case kDivC:
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddC:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  case kSubF:
  case kSubC:
  case kSubI:
  case kOrI:
  case kXorI:
  case kBinary:
    return false;
  }
  llvm_unreachable("unexpected kind");
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  // Leaf.
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  // Unary operations.
  case kAbsF:
  case kAbsC:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kSqrtF:
  case kSqrtC:
    return "sqrt";
  case kExpm1F:
  case kExpm1C:
    return "expm1";
  case kLog1pF:
  case kLog1pC:
    return "log1p";
  case kSinF:
  case kSinC:
    return "sin";
  case kTanhF:
  case kTanhC:
    return "tanh";
  case kNegF:
  case kNegC:
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kCIm:
    return "complex.im";
  case kCRe:
    return "complex.re";
  case kBinaryBranch:
    return "binary_branch";
  case kUnary:
    return "unary";
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
    return "*";
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    return "/";
  case kAddF:
  case kAddC:
  case kAddI:
    return "+";
  case kSubF:
  case kSubC:
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  case kBinary:
    return "binary";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  // Leaf.
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  // Unary operations.
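  // All unary kinds are printed prefix-style as "<symbol> <subexpression>".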
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
  case kBinaryBranch:
  case kUnary:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kAndI:
  case kOrI:
  case kXorI:
  case kShrS:
  case kShrU:
  case kShlI:
  case kBinary:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  // Leaf.
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                  tensorExps[e].op);
  case kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
      Region &absentRegion = unop.absentRegion();

      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      Value absentVal = absentYield.result();
      unsigned rhs = addExp(kInvariant, absentVal);
      return takeDisj(kind, child0, buildLattices(rhs, i), unop);
    }
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
      Region &leftRegion = binop.leftRegion();
      Region &rightRegion = binop.rightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.left_identity() || !leftRegion.empty();
      bool includeRight = binop.right_identity() || !rightRegion.empty();
      return takeCombi(kBinary, child0, child1, binop, includeLeft,
                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                       rightYield);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
             arrayAttr[1].cast<FloatAttr>().getValue().isZero();
    }
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
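  // E.g., a vector<16xf32> source combined with an i32 destination scalar
  // type yields vector<16xi32>.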
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp,
                           vtp.getNumScalableDims());
  return dtp;
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<complex::AbsOp>(def))
        return addExp(kAbsC, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(kSqrtF, e);
      if (isa<complex::SqrtOp>(def))
        return addExp(kSqrtC, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(kExpm1F, e);
      if (isa<complex::Expm1Op>(def))
        return addExp(kExpm1C, e);
      if (isa<math::Log1pOp>(def))
        return addExp(kLog1pF, e);
      if (isa<complex::Log1pOp>(def))
        return addExp(kLog1pC, e);
      if (isa<math::SinOp>(def))
        return addExp(kSinF, e);
      if (isa<complex::SinOp>(def))
        return addExp(kSinC, e);
      if (isa<math::TanhOp>(def))
        return addExp(kTanhF, e);
      if (isa<complex::TanhOp>(def))
        return addExp(kTanhC, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<complex::NegOp>(def))
        return addExp(kNegC, e);
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (isa<sparse_tensor::UnaryOp>(def))
        return addExp(kUnary, e, Value(), def);
    }
  }
  // Construct binary operations if subexpressions can be built.
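  // (This requires that both operand subexpressions were built successfully.)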
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return addExp(kDivC, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<complex::SubOp>(def))
        return addExp(kSubC, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (isa<sparse_tensor::BinaryOp>(def))
        return addExp(kBinary, e0, e1, Value(), def);
    }
  }
  // Cannot build.
  return None;
}

static Value insertYieldOp(RewriterBase &rewriter, Location loc,
                           Region &region, ValueRange vals) {
  // Make a clone of overlap region.
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.result();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.presentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.overlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  // Leaf.
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary operations.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kAbsC: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case kSqrtC:
    return rewriter.create<complex::SqrtOp>(loc, v0);
  case kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case kExpm1C:
    return rewriter.create<complex::Expm1Op>(loc, v0);
  case kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case kTanhC:
    return rewriter.create<complex::TanhOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kCIm: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::ImOp>(loc, eltType, v0);
  }
  case kCRe: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary operations.
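  // Each binary kind lowers directly to the corresponding arith/complex op
  // on the two operand values.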
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  case kBinaryBranch: // semi-ring ops with custom logic.
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
  case kUnary:
    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
  case kBinary:
    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir