//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  // Leaf.
  case kTensor:
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v && !o);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    break;
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kCIm:
  case kCRe:
    assert(x != -1u && y == -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinaryBranch:
    assert(x != -1u && y == -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kAndI:
  case kOrI:
  case kXorI:
  case kShrS:
  case kShrU:
  case kShlI:
    assert(x != -1u && y != -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinary:
    assert(x != -1u && y != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

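// Note: a lattice point's bit vector keeps one bit per (tensor, loop-index)
// pair; bit b = i * numTensors + t (see addLat() below) is set when tensor t
// contributes a condition to loop index i.
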
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
                        Operation *op) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
                              Operation *op) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = takeConj(kind, s0, s1, op);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubC)
    s1 = mapSet(kNegC, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
                           bool includeLeft, Kind ltrans, Operation *opleft,
                           bool includeRight, Kind rtrans, Operation *opright) {
  unsigned s = takeConj(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
  }
  return s;
}

unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
  assert(kAbsF <= kind && kind <= kUnary);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v, op);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
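      // (An already accepted point whose bits differ from p1 in dense
      // dimensions only makes p1 redundant, since both traverse the same
      // sparse iteration space.)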
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  // Leaf.
  case kTensor:
    return tensorExps[e].tensor == t;
  case kInvariant:
  case kIndex:
    return false;
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kBinaryBranch:
  case kUnary:
    return false;
  // Binary operations.
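  // (Divisions and shifts are only constructed with an invariant, and for
  //  divisions a provably nonzero, right-hand operand (see buildTensorExp),
  //  so only the left operand determines the condition.)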
  case kDivF: // note: x / c only
  case kDivC:
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddC:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  case kSubF:
  case kSubC:
  case kSubI:
  case kOrI:
  case kXorI:
  case kBinary:
    return false;
  }
  llvm_unreachable("unexpected kind");
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  // Leaf.
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  // Unary operations.
  case kAbsF:
  case kAbsC:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kSqrtF:
  case kSqrtC:
    return "sqrt";
  case kExpm1F:
  case kExpm1C:
    return "expm1";
  case kLog1pF:
  case kLog1pC:
    return "log1p";
  case kSinF:
  case kSinC:
    return "sin";
  case kTanhF:
  case kTanhC:
    return "tanh";
  case kNegF:
  case kNegC:
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
    return "complex.im";
  case kCRe:
    return "complex.re";
  case kBitCast:
    return "cast";
  case kBinaryBranch:
    return "binary_branch";
  case kUnary:
    return "unary";
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
    return "*";
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    return "/";
  case kAddF:
  case kAddC:
  case kAddI:
    return "+";
  case kSubF:
  case kSubC:
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  case kBinary:
    return "binary";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  // Leaf.
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
  case kBinaryBranch:
  case kUnary:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kAndI:
  case kOrI:
  case kXorI:
  case kShrS:
  case kShrU:
  case kShlI:
  case kBinary:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << " ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  // Leaf.
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                  tensorExps[e].op);
  case kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
      Region &absentRegion = unop.getAbsentRegion();

      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      Value absentVal = absentYield.getResult();
      unsigned rhs = addExp(kInvariant, absentVal);
      return takeDisj(kind, child0, buildLattices(rhs, i), unop);
    }
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //    x   | left(x) | overlap(x,y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
      Region &leftRegion = binop.getLeftRegion();
      Region &rightRegion = binop.getRightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
      bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
      return takeCombi(kBinary, child0, child1, binop, includeLeft,
                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                       rightYield);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  // Build the linalg semantics backward from yield.
  Operation *yield = op.region().front().getTerminator();
  assert(isa<linalg::YieldOp>(yield));
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
             arrayAttr[1].cast<FloatAttr>().getValue().isZero();
    }
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}

/// Ensures that sparse compiler can generate code for expression.
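/// A branch expression is admissable if it only uses block arguments,
/// linalg.index values, values that are invariant to the enclosing
/// linalg.generic, and operations inside the branch whose operands are
/// themselves admissable.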
static bool isAdmissableBranchExp(Operation *op, Block *block, Value v) {
  // Arguments are always admissable.
  if (auto arg = v.dyn_cast<BlockArgument>())
    return true;
  // Accept index anywhere.
  Operation *def = v.getDefiningOp();
  if (isa<linalg::IndexOp>(def))
    return true;
  // Operation defined outside branch.
  if (def->getBlock() != block) {
    return def->getBlock() != op->getBlock(); // invariant?
  }
  // Operation defined within branch. Anything is accepted,
  // as long as all subexpressions are admissable.
  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
    if (!isAdmissableBranchExp(op, block, def->getOperand(i)))
      return false;
  return true;
}

/// Ensures that sparse compiler can generate code for branch.
static bool isAdmissableBranch(Operation *op, Region &region) {
  if (region.empty())
    return true;
  // Build the semi-ring branch semantics backward from yield.
  Operation *yield = region.front().getTerminator();
  assert(isa<YieldOp>(yield));
  return isAdmissableBranchExp(op, &region.front(), yield->getOperand(0));
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
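  // (Cast operations additionally record the operation's result value, so
  //  that inferType() can recover the destination type during code emission.)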
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.has_value()) {
      unsigned e = x.value();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<complex::AbsOp>(def))
        return addExp(kAbsC, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(kSqrtF, e);
      if (isa<complex::SqrtOp>(def))
        return addExp(kSqrtC, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(kExpm1F, e);
      if (isa<complex::Expm1Op>(def))
        return addExp(kExpm1C, e);
      if (isa<math::Log1pOp>(def))
        return addExp(kLog1pF, e);
      if (isa<complex::Log1pOp>(def))
        return addExp(kLog1pC, e);
      if (isa<math::SinOp>(def))
        return addExp(kSinF, e);
      if (isa<complex::SinOp>(def))
        return addExp(kSinC, e);
      if (isa<math::TanhOp>(def))
        return addExp(kTanhF, e);
      if (isa<complex::TanhOp>(def))
        return addExp(kTanhC, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<complex::NegOp>(def))
        return addExp(kNegC, e);
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
        if (isAdmissableBranch(unop, unop.getPresentRegion()) &&
            isAdmissableBranch(unop, unop.getAbsentRegion()))
          return addExp(kUnary, e, Value(), def);
      }
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.has_value() && y.has_value()) {
      unsigned e0 = x.value();
      unsigned e1 = y.value();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return addExp(kDivC, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<complex::SubOp>(def))
        return addExp(kSubC, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
        if (isAdmissableBranch(binop, binop.getOverlapRegion()) &&
            (binop.getLeftIdentity() ||
             isAdmissableBranch(binop, binop.getLeftRegion())) &&
            (binop.getRightIdentity() ||
             isAdmissableBranch(binop, binop.getRightRegion())))
          return addExp(kBinary, e0, e1, Value(), def);
      }
    }
  }
  // Cannot build.
  return None;
}

static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of overlap region.
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.getResult();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

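/// Generates the code yielded by the `present` region of a sparse_tensor
/// unary operation for operand v0, or returns an empty Value() when the
/// operand or the region is missing.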
static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.getPresentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.getOverlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  // Leaf.
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary operations.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kAbsC: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case kSqrtC:
    return rewriter.create<complex::SqrtOp>(loc, v0);
  case kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case kExpm1C:
    return rewriter.create<complex::Expm1Op>(loc, v0);
  case kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case kTanhC:
    return rewriter.create<complex::TanhOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kCIm: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::ImOp>(loc, eltType, v0);
  }
  case kCRe: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary operations.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  case kBinaryBranch: // semi-ring ops with custom logic.
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
  case kUnary:
    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
  case kBinary:
    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir