//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v && !o);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
    assert(x != -1u && y == -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinaryBranch:
    assert(x != -1u && y == -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinary:
    assert(x != -1u && y != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
                        Operation *op) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
                              Operation *op) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = takeConj(kind, s0, s1, op);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
                           bool includeLeft, Kind ltrans, Operation *opleft,
                           bool includeRight, Kind rtrans, Operation *opright) {
  unsigned s = takeConj(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
  }
  return s;
}

unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
  assert(kAbsF <= kind && kind <= kUnary);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v, op);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
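      // (A previously added point that differs from p1 only in dense
      // dimensions already subsumes p1, so p1 would be redundant.)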
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, so that no other point is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kSqrtF:
    return "sqrt";
  case kExpm1F:
    return "expm1";
  case kLog1pF:
    return "log1p";
  case kSinF:
    return "sin";
  case kTanhF:
    return "tanh";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kBinaryBranch:
    return "binary_branch";
  case kUnary:
    return "unary";
  case kMulF:
    return "*";
  case kMulI:
    return "*";
  case kDivF:
    return "/";
  case kDivS:
    return "/";
  case kDivU:
    return "/";
  case kAddF:
    return "+";
  case kAddI:
    return "+";
  case kSubF:
    return "-";
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  case kBinary:
    return "binary";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                  tensorExps[e].op);
  case kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
      Region &absentRegion = unop.absentRegion();

      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      } else {
        // Use a disjunction with `unop` on the left and the absent value as an
        // invariant on the right.
        Block &absentBlock = absentRegion.front();
        YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
        Value absentVal = absentYield.result();
        unsigned rhs = addExp(kInvariant, absentVal);
        return takeDisj(kind, child0, buildLattices(rhs, i), unop);
      }
    }
  case kMulF:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur on the left-hand side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
      Region &leftRegion = binop.leftRegion();
      Region &rightRegion = binop.rightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.left_identity() || !leftRegion.empty();
      bool includeRight = binop.right_identity() || !rightRegion.empty();
      return takeCombi(kBinary, child0, child1, binop, includeLeft,
                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                       rightYield);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
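/// That is, only an invariant backed by a provably nonzero integer or
/// floating-point constant yields false; everything else is conservatively
/// treated as possibly zero.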
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(kSqrtF, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(kExpm1F, e);
      if (isa<math::Log1pOp>(def))
        return addExp(kLog1pF, e);
      if (isa<math::SinOp>(def))
        return addExp(kSinF, e);
      if (isa<math::TanhOp>(def))
        return addExp(kTanhF, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (isa<sparse_tensor::UnaryOp>(def))
        return addExp(kUnary, e, Value(), def);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (isa<sparse_tensor::BinaryOp>(def))
        return addExp(kBinary, e0, e1, Value(), def);
    }
  }
  // Cannot build.
  return None;
}

static Value insertYieldOp(PatternRewriter &rewriter, Location loc,
                           Region &region, ValueRange vals) {
  // Make a clone of overlap region.
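  // (Cloning keeps the original op's body intact; mergeBlockBefore below
  // consumes the cloned block and replaces its arguments with `vals`.)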
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.result();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(PatternRewriter &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.presentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(PatternRewriter &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.overlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
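  // (Each case below materializes the matching arith/math operation on the
  // value(s) already generated for the child expression(s).)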
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary ops.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  // Semiring ops with custom logic.
  case kBinaryBranch:
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
  case kUnary:
    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
  case kBinary:
    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir