//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v && !o);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
    assert(x != -1u && y == -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinaryBranch:
    assert(x != -1u && y == -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinary:
    assert(x != -1u && y != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
                        Operation *op) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
  return e;
}

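// Each lattice point records its iteration condition as a bit set over all
// (tensor, loop-index) pairs, encoded below as bit b = numTensors * i + t.
// For example, with numTensors == 3, tensor t == 1 at loop i == 2 occupies
// bit 3 * 2 + 1 == 7, which tensor(b) and index(b) decode again.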
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
                              Operation *op) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

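// Example for the set combinators below: for an expression x(i) OP y(i)
// over two sparse operands, each operand starts out as a one-point lattice
// set. takeConj() combines these into the single point {x,y} (both operands
// contribute), whereas takeDisj() additionally keeps the points {x} and {y},
// so regions where only one operand is nonzero are still visited.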
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = takeConj(kind, s0, s1, op);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
                           bool includeLeft, Kind ltrans, Operation *opleft,
                           bool includeRight, Kind rtrans, Operation *opright) {
  unsigned s = takeConj(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
  }
  return s;
}

unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
  assert(kAbsF <= kind && kind <= kUnary);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v, op);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

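// For example, a last (singleton) lattice point with condition
// {sparse A_i, dense B_i} simplifies to just {A_i}: a dense dimension
// always has an entry, so only the sparse iterator needs to be tested
// during code generation.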
BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e., the
  // last point in a lattice for which no other point is less than it.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

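// For example, x(i) * 2 remains a single condition on the tensor that
// holds x (an invariant factor adds no lattice points), whereas x(i) + 2
// does not, since the addition also contributes values at positions
// where x is zero.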
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kSqrtF:
    return "sqrt";
  case kExpm1F:
    return "expm1";
  case kLog1pF:
    return "log1p";
  case kSinF:
    return "sin";
  case kTanhF:
    return "tanh";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kBinaryBranch:
    return "binary_branch";
  case kUnary:
    return "unary";
  case kMulF:
    return "*";
  case kMulI:
    return "*";
  case kDivF:
    return "/";
  case kDivS:
    return "/";
  case kDivU:
    return "/";
  case kAddF:
    return "+";
  case kAddI:
    return "+";
  case kSubF:
    return "-";
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  case kBinary:
    return "binary";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << " ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                  tensorExps[e].op);
  case kUnary:
    // A custom unary operation.
    //
    //  op y| !y       | y          |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
      Region &absentRegion = unop.absentRegion();

      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      Value absentVal = absentYield.result();
      unsigned rhs = addExp(kInvariant, absentVal);
      return takeDisj(kind, child0, buildLattices(rhs, i), unop);
    }
  case kMulF:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kBinary:
    // A custom binary operation.
    //
    //  x op y| !y      | y            |
    //  ------+---------+--------------+
    //  !x    | empty   | right(y)     |
    //  x     | left(x) | overlap(x,y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
      Region &leftRegion = binop.leftRegion();
      Region &rightRegion = binop.rightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.left_identity() || !leftRegion.empty();
      bool includeRight = binop.right_identity() || !rightRegion.empty();
      return takeCombi(kBinary, child0, child1, binop, includeLeft,
                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                       rightYield);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero constant.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
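  // For example, a cast with scalar destination type f32 applied to a
  // vector<8xf64> source value is typed vector<8xf32>.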
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(kSqrtF, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(kExpm1F, e);
      if (isa<math::Log1pOp>(def))
        return addExp(kLog1pF, e);
      if (isa<math::SinOp>(def))
        return addExp(kSinF, e);
      if (isa<math::TanhOp>(def))
        return addExp(kTanhF, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (isa<sparse_tensor::UnaryOp>(def))
        return addExp(kUnary, e, Value(), def);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (isa<sparse_tensor::BinaryOp>(def))
        return addExp(kBinary, e0, e1, Value(), def);
    }
  }
  // Cannot build.
  return None;
}

static Value insertYieldOp(PatternRewriter &rewriter, Location loc,
                           Region &region, ValueRange vals) {
  // Make a clone of overlap region.
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.result();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(PatternRewriter &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.presentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(PatternRewriter &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.overlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary ops.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  // Semiring ops with custom logic.
  case kBinaryBranch:
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
  case kUnary:
    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
  case kBinary:
    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir