//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v && !o);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    break;
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kCIm:
  case kCRe:
    assert(x != -1u && y == -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinaryBranch:
    assert(x != -1u && y == -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinary:
    assert(x != -1u && y != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
                        Operation *op) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
                              Operation *op) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = takeConj(kind, s0, s1, op);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubC)
    s1 = mapSet(kNegC, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
                           bool includeLeft, Kind ltrans, Operation *opleft,
                           bool includeRight, Kind rtrans, Operation *opright) {
  unsigned s = takeConj(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
  }
  return s;
}

unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
  assert(kAbsF <= kind && kind <= kUnary);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v, op);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
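      // If this lattice point differs from an already-kept point in dense
      // dimensions only, it adds no new sparse condition and can be skipped,
      // since dense dimensions are iterated in full anyway.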
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, such that no other point is less than it.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivC:
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddC:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  case kAbsF:
  case kAbsC:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kSqrtF:
  case kSqrtC:
    return "sqrt";
  case kExpm1F:
  case kExpm1C:
    return "expm1";
  case kLog1pF:
  case kLog1pC:
    return "log1p";
  case kSinF:
  case kSinC:
    return "sin";
  case kTanhF:
  case kTanhC:
    return "tanh";
  case kNegF:
  case kNegC:
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kCIm:
    return "complex.im";
  case kCRe:
    return "complex.re";
  case kBinaryBranch:
    return "binary_branch";
  case kUnary:
    return "unary";
  case kMulF:
  case kMulC:
  case kMulI:
    return "*";
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    return "/";
  case kAddF:
  case kAddC:
  case kAddI:
    return "+";
  case kSubF:
  case kSubC:
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  case kBinary:
    return "binary";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kCIm:
  case kCRe:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

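// Prints each set bit as i_<tensor>_<loop>_<c>, where <c> is S, D, T, or U
// for the kSparse, kDense, kSingle, or kUndef dimension kind, respectively.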
void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kCIm:
  case kCRe:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                  tensorExps[e].op);
  case kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
      Region &absentRegion = unop.absentRegion();

      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
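      // The absent region computes the value emitted where the operand is
      // missing; it is modeled as an invariant expression so that takeDisj
      // includes it in the disjunction.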
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      Value absentVal = absentYield.result();
      unsigned rhs = addExp(kInvariant, absentVal);
      return takeDisj(kind, child0, buildLattices(rhs, i), unop);
    }
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
      Region &leftRegion = binop.leftRegion();
      Region &rightRegion = binop.rightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.left_identity() || !leftRegion.empty();
      bool includeRight = binop.right_identity() || !rightRegion.empty();
      return takeCombi(kBinary, child0, child1, binop, includeLeft,
                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                       rightYield);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
             arrayAttr[1].cast<FloatAttr>().getValue().isZero();
    }
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
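  // Each recognized op maps to its TensorExp kind; cast-like ops also record
  // the result value so inferType() can later recover the destination type.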
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<complex::AbsOp>(def))
        return addExp(kAbsC, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(kSqrtF, e);
      if (isa<complex::SqrtOp>(def))
        return addExp(kSqrtC, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(kExpm1F, e);
      if (isa<complex::Expm1Op>(def))
        return addExp(kExpm1C, e);
      if (isa<math::Log1pOp>(def))
        return addExp(kLog1pF, e);
      if (isa<complex::Log1pOp>(def))
        return addExp(kLog1pC, e);
      if (isa<math::SinOp>(def))
        return addExp(kSinF, e);
      if (isa<complex::SinOp>(def))
        return addExp(kSinC, e);
      if (isa<math::TanhOp>(def))
        return addExp(kTanhF, e);
      if (isa<complex::TanhOp>(def))
        return addExp(kTanhC, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<complex::NegOp>(def))
        return addExp(kNegC, e);
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (isa<sparse_tensor::UnaryOp>(def))
        return addExp(kUnary, e, Value(), def);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return addExp(kDivC, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<complex::SubOp>(def))
        return addExp(kSubC, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (isa<sparse_tensor::BinaryOp>(def))
        return addExp(kBinary, e0, e1, Value(), def);
    }
  }
  // Cannot build.
  return None;
}

static Value insertYieldOp(RewriterBase &rewriter, Location loc,
                           Region &region, ValueRange vals) {
  // Make a clone of overlap region.
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.result();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.presentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.overlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kAbsC: {
    auto type = v0.getType().template cast<ComplexType>();
    auto eltType = type.getElementType().template cast<FloatType>();
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case kSqrtC:
    return rewriter.create<complex::SqrtOp>(loc, v0);
  case kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case kExpm1C:
    return rewriter.create<complex::Expm1Op>(loc, v0);
  case kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case kTanhC:
    return rewriter.create<complex::TanhOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kCIm:
  case kCRe: {
    auto type = v0.getType().template cast<ComplexType>();
    auto eltType = type.getElementType().template cast<FloatType>();
    if (tensorExps[e].kind == kCIm)
      return rewriter.create<complex::ImOp>(loc, eltType, v0);
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary ops.
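  // Each binary case materializes the corresponding arith/complex op on the
  // operand values already generated for the current loop nest.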
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  // Semiring ops with custom logic.
  case kBinaryBranch:
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
  case kUnary:
    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
  case kBinary:
    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir