1 //===- Merger.cpp - Implementation of iteration lattices ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/Complex/IR/Complex.h" 12 #include "mlir/Dialect/Math/IR/Math.h" 13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" 14 15 #include "mlir/IR/Operation.h" 16 #include "llvm/Support/Debug.h" 17 18 namespace mlir { 19 namespace sparse_tensor { 20 21 //===----------------------------------------------------------------------===// 22 // Constructors. 23 //===----------------------------------------------------------------------===// 24 25 TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o) 26 : kind(k), val(v), op(o) { 27 switch (kind) { 28 case kTensor: 29 assert(x != -1u && y == -1u && !v && !o); 30 tensor = x; 31 break; 32 case kInvariant: 33 assert(x == -1u && y == -1u && v && !o); 34 break; 35 case kIndex: 36 assert(x != -1u && y == -1u && !v && !o); 37 index = x; 38 break; 39 case kAbsF: 40 case kCeilF: 41 case kFloorF: 42 case kSqrtF: 43 case kExpm1F: 44 case kLog1pF: 45 case kSinF: 46 case kTanhF: 47 case kNegF: 48 case kNegI: 49 assert(x != -1u && y == -1u && !v && !o); 50 children.e0 = x; 51 children.e1 = y; 52 break; 53 case kTruncF: 54 case kExtF: 55 case kCastFS: 56 case kCastFU: 57 case kCastSF: 58 case kCastUF: 59 case kCastS: 60 case kCastU: 61 case kCastIdx: 62 case kTruncI: 63 case kBitCast: 64 assert(x != -1u && y == -1u && v && !o); 65 children.e0 = x; 66 children.e1 = y; 67 break; 68 case kBinaryBranch: 69 assert(x != -1u && y == -1u && !v && o); 70 children.e0 = x; 71 children.e1 = y; 72 break; 73 case kUnary: 74 // No assertion on y can be made, as the branching paths involve both 75 // a unary (mapSet) and binary (takeDisj) pathway. 76 assert(x != -1u && !v && o); 77 children.e0 = x; 78 children.e1 = y; 79 break; 80 case kBinary: 81 assert(x != -1u && y != -1u && !v && o); 82 children.e0 = x; 83 children.e1 = y; 84 break; 85 default: 86 assert(x != -1u && y != -1u && !v && !o); 87 children.e0 = x; 88 children.e1 = y; 89 break; 90 } 91 } 92 93 LatPoint::LatPoint(unsigned n, unsigned e, unsigned b) 94 : bits(n, false), simple(), exp(e) { 95 bits.set(b); 96 } 97 98 LatPoint::LatPoint(const BitVector &b, unsigned e) 99 : bits(b), simple(), exp(e) {} 100 101 //===----------------------------------------------------------------------===// 102 // Lattice methods. 103 //===----------------------------------------------------------------------===// 104 105 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v, 106 Operation *op) { 107 unsigned e = tensorExps.size(); 108 tensorExps.push_back(TensorExp(k, e0, e1, v, op)); 109 return e; 110 } 111 112 unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) { 113 assert(t < numTensors && i < numLoops); 114 unsigned p = latPoints.size(); 115 latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t)); 116 return p; 117 } 118 119 unsigned Merger::addSet() { 120 unsigned s = latSets.size(); 121 latSets.emplace_back(SmallVector<unsigned, 16>()); 122 return s; 123 } 124 125 unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1, 126 Operation *op) { 127 unsigned p = latPoints.size(); 128 BitVector nb = BitVector(latPoints[p0].bits); 129 nb |= latPoints[p1].bits; 130 unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op); 131 latPoints.push_back(LatPoint(nb, e)); 132 return p; 133 } 134 135 unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) { 136 unsigned s = addSet(); 137 for (unsigned p0 : latSets[s0]) 138 for (unsigned p1 : latSets[s1]) 139 latSets[s].push_back(conjLatPoint(kind, p0, p1, op)); 140 return s; 141 } 142 143 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) { 144 unsigned s = takeConj(kind, s0, s1, op); 145 // Followed by all in s0. 146 for (unsigned p : latSets[s0]) 147 latSets[s].push_back(p); 148 // Map binary 0-y to unary -y. 149 // TODO: move this if-else logic into buildLattices 150 if (kind == kSubF) 151 s1 = mapSet(kNegF, s1); 152 else if (kind == kSubI) 153 s1 = mapSet(kNegI, s1); 154 // Followed by all in s1. 155 for (unsigned p : latSets[s1]) 156 latSets[s].push_back(p); 157 return s; 158 } 159 160 unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig, 161 bool includeLeft, Kind ltrans, Operation *opleft, 162 bool includeRight, Kind rtrans, Operation *opright) { 163 unsigned s = takeConj(kind, s0, s1, orig); 164 // Left Region. 165 if (includeLeft) { 166 if (opleft) 167 s0 = mapSet(ltrans, s0, Value(), opleft); 168 for (unsigned p : latSets[s0]) 169 latSets[s].push_back(p); 170 } 171 // Right Region. 172 if (includeRight) { 173 if (opright) 174 s1 = mapSet(rtrans, s1, Value(), opright); 175 for (unsigned p : latSets[s1]) 176 latSets[s].push_back(p); 177 } 178 return s; 179 } 180 181 unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) { 182 assert(kAbsF <= kind && kind <= kUnary); 183 unsigned s = addSet(); 184 for (unsigned p : latSets[s0]) { 185 unsigned e = addExp(kind, latPoints[p].exp, v, op); 186 latPoints.push_back(LatPoint(latPoints[p].bits, e)); 187 latSets[s].push_back(latPoints.size() - 1); 188 } 189 return s; 190 } 191 192 unsigned Merger::optimizeSet(unsigned s0) { 193 unsigned s = addSet(); 194 assert(!latSets[s0].empty()); 195 unsigned p0 = latSets[s0][0]; 196 for (unsigned p1 : latSets[s0]) { 197 bool add = true; 198 if (p0 != p1) { 199 // Is this a straightforward copy? 200 unsigned e = latPoints[p1].exp; 201 if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor) 202 continue; 203 // Conjunction already covered? 204 for (unsigned p2 : latSets[s]) { 205 assert(!latGT(p1, p2)); // Lj => Li would be bad 206 if (onlyDenseDiff(p2, p1)) { 207 add = false; 208 break; 209 } 210 } 211 assert(!add || latGT(p0, p1)); 212 } 213 if (add) 214 latSets[s].push_back(p1); 215 } 216 for (unsigned p : latSets[s]) 217 latPoints[p].simple = simplifyCond(s, p); 218 return s; 219 } 220 221 BitVector Merger::simplifyCond(unsigned s0, unsigned p0) { 222 // First determine if this lattice point is a *singleton*, i.e., 223 // the last point in a lattice, no other is less than this one. 224 bool isSingleton = true; 225 for (unsigned p1 : latSets[s0]) { 226 if (p0 != p1 && latGT(p0, p1)) { 227 isSingleton = false; 228 break; 229 } 230 } 231 // Now apply the two basic rules. 232 BitVector simple = latPoints[p0].bits; 233 bool reset = isSingleton && hasAnyDimOf(simple, kSparse); 234 for (unsigned b = 0, be = simple.size(); b < be; b++) { 235 if (simple[b] && !isDim(b, kSparse)) { 236 if (reset) 237 simple.reset(b); 238 reset = true; 239 } 240 } 241 return simple; 242 } 243 244 bool Merger::latGT(unsigned i, unsigned j) const { 245 const BitVector &bitsi = latPoints[i].bits; 246 const BitVector &bitsj = latPoints[j].bits; 247 assert(bitsi.size() == bitsj.size()); 248 if (bitsi.count() > bitsj.count()) { 249 for (unsigned b = 0, be = bitsj.size(); b < be; b++) 250 if (bitsj[b] && !bitsi[b]) 251 return false; 252 return true; 253 } 254 return false; 255 } 256 257 bool Merger::onlyDenseDiff(unsigned i, unsigned j) { 258 BitVector tmp = latPoints[j].bits; 259 tmp ^= latPoints[i].bits; 260 return !hasAnyDimOf(tmp, kSparse); 261 } 262 263 bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const { 264 for (unsigned b = 0, be = bits.size(); b < be; b++) 265 if (bits[b] && isDim(b, d)) 266 return true; 267 return false; 268 } 269 270 bool Merger::isSingleCondition(unsigned t, unsigned e) const { 271 switch (tensorExps[e].kind) { 272 case kTensor: 273 return tensorExps[e].tensor == t; 274 case kAbsF: 275 case kCeilF: 276 case kFloorF: 277 case kSqrtF: 278 case kExpm1F: 279 case kLog1pF: 280 case kSinF: 281 case kTanhF: 282 case kNegF: 283 case kNegI: 284 case kTruncF: 285 case kExtF: 286 case kCastFS: 287 case kCastFU: 288 case kCastSF: 289 case kCastUF: 290 case kCastS: 291 case kCastU: 292 case kCastIdx: 293 case kTruncI: 294 case kBitCast: 295 return isSingleCondition(t, tensorExps[e].children.e0); 296 case kDivF: // note: x / c only 297 case kDivS: 298 case kDivU: 299 assert(!maybeZero(tensorExps[e].children.e1)); 300 return isSingleCondition(t, tensorExps[e].children.e0); 301 case kShrS: // note: x >> inv only 302 case kShrU: 303 case kShlI: 304 assert(isInvariant(tensorExps[e].children.e1)); 305 return isSingleCondition(t, tensorExps[e].children.e0); 306 case kMulF: 307 case kMulC: 308 case kMulI: 309 case kAndI: 310 if (isSingleCondition(t, tensorExps[e].children.e0)) 311 return isSingleCondition(t, tensorExps[e].children.e1) || 312 isInvariant(tensorExps[e].children.e1); 313 if (isSingleCondition(t, tensorExps[e].children.e1)) 314 return isInvariant(tensorExps[e].children.e0); 315 return false; 316 case kAddF: 317 case kAddC: 318 case kAddI: 319 return isSingleCondition(t, tensorExps[e].children.e0) && 320 isSingleCondition(t, tensorExps[e].children.e1); 321 default: 322 return false; 323 } 324 } 325 326 #ifndef NDEBUG 327 328 //===----------------------------------------------------------------------===// 329 // Print methods (for debugging). 330 //===----------------------------------------------------------------------===// 331 332 static const char *kindToOpSymbol(Kind kind) { 333 switch (kind) { 334 case kTensor: 335 return "tensor"; 336 case kInvariant: 337 return "invariant"; 338 case kIndex: 339 return "index"; 340 case kAbsF: 341 return "abs"; 342 case kCeilF: 343 return "ceil"; 344 case kFloorF: 345 return "floor"; 346 case kSqrtF: 347 return "sqrt"; 348 case kExpm1F: 349 return "expm1"; 350 case kLog1pF: 351 return "log1p"; 352 case kSinF: 353 return "sin"; 354 case kTanhF: 355 return "tanh"; 356 case kNegF: 357 return "-"; 358 case kNegI: 359 return "-"; 360 case kTruncF: 361 case kExtF: 362 case kCastFS: 363 case kCastFU: 364 case kCastSF: 365 case kCastUF: 366 case kCastS: 367 case kCastU: 368 case kCastIdx: 369 case kTruncI: 370 case kBitCast: 371 return "cast"; 372 case kBinaryBranch: 373 return "binary_branch"; 374 case kUnary: 375 return "unary"; 376 case kMulF: 377 case kMulC: 378 case kMulI: 379 return "*"; 380 case kDivF: 381 case kDivS: 382 case kDivU: 383 return "/"; 384 case kAddF: 385 case kAddC: 386 case kAddI: 387 return "+"; 388 case kSubF: 389 case kSubI: 390 return "-"; 391 case kAndI: 392 return "&"; 393 case kOrI: 394 return "|"; 395 case kXorI: 396 return "^"; 397 case kShrS: 398 return "a>>"; 399 case kShrU: 400 return ">>"; 401 case kShlI: 402 return "<<"; 403 case kBinary: 404 return "binary"; 405 } 406 llvm_unreachable("unexpected kind for symbol"); 407 } 408 409 void Merger::dumpExp(unsigned e) const { 410 switch (tensorExps[e].kind) { 411 case kTensor: 412 if (tensorExps[e].tensor == syntheticTensor) 413 llvm::dbgs() << "synthetic_"; 414 else if (tensorExps[e].tensor == outTensor) 415 llvm::dbgs() << "output_"; 416 llvm::dbgs() << "tensor_" << tensorExps[e].tensor; 417 break; 418 case kInvariant: 419 llvm::dbgs() << "invariant"; 420 break; 421 case kIndex: 422 llvm::dbgs() << "index_" << tensorExps[e].index; 423 break; 424 case kAbsF: 425 case kCeilF: 426 case kFloorF: 427 case kSqrtF: 428 case kExpm1F: 429 case kLog1pF: 430 case kSinF: 431 case kTanhF: 432 case kNegF: 433 case kNegI: 434 case kTruncF: 435 case kExtF: 436 case kCastFS: 437 case kCastFU: 438 case kCastSF: 439 case kCastUF: 440 case kCastS: 441 case kCastU: 442 case kCastIdx: 443 case kTruncI: 444 case kBitCast: 445 llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " "; 446 dumpExp(tensorExps[e].children.e0); 447 break; 448 default: 449 llvm::dbgs() << "("; 450 dumpExp(tensorExps[e].children.e0); 451 llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " "; 452 dumpExp(tensorExps[e].children.e1); 453 llvm::dbgs() << ")"; 454 } 455 } 456 457 void Merger::dumpLat(unsigned p) const { 458 llvm::dbgs() << "lat("; 459 dumpBits(latPoints[p].bits); 460 llvm::dbgs() << " :"; 461 dumpBits(latPoints[p].simple); 462 llvm::dbgs() << " : "; 463 dumpExp(latPoints[p].exp); 464 llvm::dbgs() << " )\n"; 465 } 466 467 void Merger::dumpSet(unsigned s) const { 468 llvm::dbgs() << "{ #" << latSets[s].size() << "\n"; 469 for (unsigned p : latSets[s]) { 470 llvm::dbgs() << " "; 471 dumpLat(p); 472 } 473 llvm::dbgs() << "}\n"; 474 } 475 476 void Merger::dumpBits(const BitVector &bits) const { 477 for (unsigned b = 0, be = bits.size(); b < be; b++) { 478 if (bits[b]) { 479 unsigned t = tensor(b); 480 unsigned i = index(b); 481 llvm::dbgs() << " i_" << t << "_" << i << "_"; 482 switch (dims[t][i]) { 483 case kSparse: 484 llvm::dbgs() << "S"; 485 break; 486 case kDense: 487 llvm::dbgs() << "D"; 488 break; 489 case kSingle: 490 llvm::dbgs() << "T"; 491 break; 492 case kUndef: 493 llvm::dbgs() << "U"; 494 break; 495 } 496 } 497 } 498 } 499 500 #endif // NDEBUG 501 502 //===----------------------------------------------------------------------===// 503 // Builder methods. 504 //===----------------------------------------------------------------------===// 505 506 unsigned Merger::buildLattices(unsigned e, unsigned i) { 507 Kind kind = tensorExps[e].kind; 508 switch (kind) { 509 case kTensor: 510 case kInvariant: 511 case kIndex: { 512 // Either the index is really used in the tensor expression, or it is 513 // set to the undefined index in that dimension. An invariant expression, 514 // a proper index value, and a truly dynamic sparse output tensor are set 515 // to a synthetic tensor with undefined indices only to ensure the 516 // iteration space is not skipped as a result of their contents. 517 unsigned s = addSet(); 518 unsigned t = syntheticTensor; 519 if (kind == kTensor) { 520 t = tensorExps[e].tensor; 521 if (hasSparseOut && t == outTensor) 522 t = syntheticTensor; 523 } 524 latSets[s].push_back(addLat(t, i, e)); 525 return s; 526 } 527 case kAbsF: 528 case kCeilF: 529 case kFloorF: 530 case kSqrtF: 531 case kExpm1F: 532 case kLog1pF: 533 case kSinF: 534 case kTanhF: 535 case kNegF: 536 case kNegI: 537 case kTruncF: 538 case kExtF: 539 case kCastFS: 540 case kCastFU: 541 case kCastSF: 542 case kCastUF: 543 case kCastS: 544 case kCastU: 545 case kCastIdx: 546 case kTruncI: 547 case kBitCast: 548 // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the 549 // lattice set of the operand through the operator into a new set. 550 // 551 // -y|!y | y | 552 // --+---+---+ 553 // | 0 |-y | 554 return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), 555 tensorExps[e].val); 556 case kBinaryBranch: 557 // The left or right half of a binary operation which has already 558 // been split into separate operations for each region. 559 return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(), 560 tensorExps[e].op); 561 case kUnary: 562 // A custom unary operation. 563 // 564 // op y| !y | y | 565 // ----+----------+------------+ 566 // | absent() | present(y) | 567 { 568 unsigned child0 = buildLattices(tensorExps[e].children.e0, i); 569 UnaryOp unop = cast<UnaryOp>(tensorExps[e].op); 570 Region &absentRegion = unop.absentRegion(); 571 572 if (absentRegion.empty()) { 573 // Simple mapping over existing values. 574 return mapSet(kind, child0, Value(), unop); 575 } // Use a disjunction with `unop` on the left and the absent value as an 576 // invariant on the right. 577 Block &absentBlock = absentRegion.front(); 578 YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator()); 579 Value absentVal = absentYield.result(); 580 unsigned rhs = addExp(kInvariant, absentVal); 581 return takeDisj(kind, child0, buildLattices(rhs, i), unop); 582 } 583 case kMulF: 584 case kMulC: 585 case kMulI: 586 case kAndI: 587 // A multiplicative operation only needs to be performed 588 // for the conjunction of sparse iteration spaces. 589 // 590 // x*y|!y | y | 591 // ---+---+---+ 592 // !x | 0 | 0 | 593 // x | 0 |x*y| 594 // 595 // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored. 596 return takeConj(kind, // take binary conjunction 597 buildLattices(tensorExps[e].children.e0, i), 598 buildLattices(tensorExps[e].children.e1, i)); 599 case kDivF: 600 case kDivS: 601 case kDivU: 602 // A division is tricky, since 0/0, 0/c, c/0 all have 603 // specific outcomes for floating-point and integers. 604 // Thus, we need to traverse the full iteration space. 605 // 606 // x/y|!y | y | 607 // ---+---+---+ 608 // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero 609 // x |x/0|x/y| INT: x/0=exception for any x 610 // 611 // TODO: for now we "fixed" this by only accepting x/c cases 612 // during expression building, so that the conjunction 613 // rules applies (viz. x/c = x*(1/c) as far as lattice 614 // construction is concerned). 615 assert(!maybeZero(tensorExps[e].children.e1)); 616 return takeConj(kind, // take binary conjunction 617 buildLattices(tensorExps[e].children.e0, i), 618 buildLattices(tensorExps[e].children.e1, i)); 619 case kAddF: 620 case kAddC: 621 case kAddI: 622 case kSubF: 623 case kSubI: 624 case kOrI: 625 case kXorI: 626 // An additive operation needs to be performed 627 // for the disjunction of sparse iteration spaces. 628 // 629 // x+y|!y | y | x-y|!y | y | 630 // ---+---+---+ ---+---+---+ 631 // !x | 0 | y | !x | 0 |-y | 632 // x | x |x+y| x | x |x-y| 633 return takeDisj(kind, // take binary disjunction 634 buildLattices(tensorExps[e].children.e0, i), 635 buildLattices(tensorExps[e].children.e1, i)); 636 case kShrS: 637 case kShrU: 638 case kShlI: 639 // A shift operation by an invariant amount (viz. tensor expressions 640 // can only occur at the left-hand-side of the operator) can be handled 641 // with the conjuction rule. 642 assert(isInvariant(tensorExps[e].children.e1)); 643 return takeConj(kind, // take binary conjunction 644 buildLattices(tensorExps[e].children.e0, i), 645 buildLattices(tensorExps[e].children.e1, i)); 646 case kBinary: 647 // A custom binary operation. 648 // 649 // x op y| !y | y | 650 // ------+---------+--------------+ 651 // !x | empty | right(y) | 652 // x | left(x) | overlap(x,y) | 653 { 654 unsigned child0 = buildLattices(tensorExps[e].children.e0, i); 655 unsigned child1 = buildLattices(tensorExps[e].children.e1, i); 656 BinaryOp binop = cast<BinaryOp>(tensorExps[e].op); 657 Region &leftRegion = binop.leftRegion(); 658 Region &rightRegion = binop.rightRegion(); 659 // Left Region. 660 Operation *leftYield = nullptr; 661 if (!leftRegion.empty()) { 662 Block &leftBlock = leftRegion.front(); 663 leftYield = leftBlock.getTerminator(); 664 } 665 // Right Region. 666 Operation *rightYield = nullptr; 667 if (!rightRegion.empty()) { 668 Block &rightBlock = rightRegion.front(); 669 rightYield = rightBlock.getTerminator(); 670 } 671 bool includeLeft = binop.left_identity() || !leftRegion.empty(); 672 bool includeRight = binop.right_identity() || !rightRegion.empty(); 673 return takeCombi(kBinary, child0, child1, binop, includeLeft, 674 kBinaryBranch, leftYield, includeRight, kBinaryBranch, 675 rightYield); 676 } 677 } 678 llvm_unreachable("unexpected expression kind"); 679 } 680 681 Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) { 682 Operation *yield = op.region().front().getTerminator(); 683 return buildTensorExp(op, yield->getOperand(0)); 684 } 685 686 /// Only returns false if we are certain this is a nonzero. 687 bool Merger::maybeZero(unsigned e) const { 688 if (tensorExps[e].kind == kInvariant) { 689 if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>()) 690 return c.value() == 0; 691 if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>()) 692 return c.value().isZero(); 693 } 694 return true; 695 } 696 697 bool Merger::isInvariant(unsigned e) const { 698 return tensorExps[e].kind == kInvariant; 699 } 700 701 Type Merger::inferType(unsigned e, Value src) { 702 // Obtain the destination type from the cast node. 703 Type dtp = tensorExps[e].val.getType(); 704 // Inspect source type. For vector types, apply the same 705 // vectorization to the destination type. 706 if (auto vtp = src.getType().dyn_cast<VectorType>()) 707 return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims()); 708 return dtp; 709 } 710 711 Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) { 712 if (auto arg = v.dyn_cast<BlockArgument>()) { 713 unsigned argN = arg.getArgNumber(); 714 // Any argument of the generic op that is not marked as a scalar 715 // argument is considered a tensor, indexed by the implicit loop 716 // bounds. This includes rank-0 tensor arguments. 717 if (arg.getOwner()->getParentOp() == op) { 718 OpOperand *t = op.getInputAndOutputOperands()[argN]; 719 if (!op.isScalar(t)) 720 return addExp(kTensor, argN); 721 v = t->get(); // get scalar value 722 } 723 // Any other argument (marked as scalar argument for the generic op 724 // or belonging to an enveloping op) is considered invariant. 725 return addExp(kInvariant, v); 726 } 727 // Something defined outside is invariant. 728 Operation *def = v.getDefiningOp(); 729 if (def->getBlock() != &op.region().front()) 730 return addExp(kInvariant, v); 731 // Construct index operations. 732 if (def->getNumOperands() == 0) { 733 if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) 734 return addExp(kIndex, indexOp.dim()); 735 } 736 // Construct unary operations if subexpression can be built. 737 if (def->getNumOperands() == 1) { 738 auto x = buildTensorExp(op, def->getOperand(0)); 739 if (x.hasValue()) { 740 unsigned e = x.getValue(); 741 if (isa<math::AbsOp>(def)) 742 return addExp(kAbsF, e); 743 if (isa<math::CeilOp>(def)) 744 return addExp(kCeilF, e); 745 if (isa<math::FloorOp>(def)) 746 return addExp(kFloorF, e); 747 if (isa<math::SqrtOp>(def)) 748 return addExp(kSqrtF, e); 749 if (isa<math::ExpM1Op>(def)) 750 return addExp(kExpm1F, e); 751 if (isa<math::Log1pOp>(def)) 752 return addExp(kLog1pF, e); 753 if (isa<math::SinOp>(def)) 754 return addExp(kSinF, e); 755 if (isa<math::TanhOp>(def)) 756 return addExp(kTanhF, e); 757 if (isa<arith::NegFOp>(def)) 758 return addExp(kNegF, e); // no negi in std 759 if (isa<arith::TruncFOp>(def)) 760 return addExp(kTruncF, e, v); 761 if (isa<arith::ExtFOp>(def)) 762 return addExp(kExtF, e, v); 763 if (isa<arith::FPToSIOp>(def)) 764 return addExp(kCastFS, e, v); 765 if (isa<arith::FPToUIOp>(def)) 766 return addExp(kCastFU, e, v); 767 if (isa<arith::SIToFPOp>(def)) 768 return addExp(kCastSF, e, v); 769 if (isa<arith::UIToFPOp>(def)) 770 return addExp(kCastUF, e, v); 771 if (isa<arith::ExtSIOp>(def)) 772 return addExp(kCastS, e, v); 773 if (isa<arith::ExtUIOp>(def)) 774 return addExp(kCastU, e, v); 775 if (isa<arith::IndexCastOp>(def)) 776 return addExp(kCastIdx, e, v); 777 if (isa<arith::TruncIOp>(def)) 778 return addExp(kTruncI, e, v); 779 if (isa<arith::BitcastOp>(def)) 780 return addExp(kBitCast, e, v); 781 if (isa<sparse_tensor::UnaryOp>(def)) 782 return addExp(kUnary, e, Value(), def); 783 } 784 } 785 // Construct binary operations if subexpressions can be built. 786 // See buildLattices() for an explanation of rejecting certain 787 // division and shift operations 788 if (def->getNumOperands() == 2) { 789 auto x = buildTensorExp(op, def->getOperand(0)); 790 auto y = buildTensorExp(op, def->getOperand(1)); 791 if (x.hasValue() && y.hasValue()) { 792 unsigned e0 = x.getValue(); 793 unsigned e1 = y.getValue(); 794 if (isa<arith::MulFOp>(def)) 795 return addExp(kMulF, e0, e1); 796 if (isa<complex::MulOp>(def)) 797 return addExp(kMulC, e0, e1); 798 if (isa<arith::MulIOp>(def)) 799 return addExp(kMulI, e0, e1); 800 if (isa<arith::DivFOp>(def) && !maybeZero(e1)) 801 return addExp(kDivF, e0, e1); 802 if (isa<arith::DivSIOp>(def) && !maybeZero(e1)) 803 return addExp(kDivS, e0, e1); 804 if (isa<arith::DivUIOp>(def) && !maybeZero(e1)) 805 return addExp(kDivU, e0, e1); 806 if (isa<arith::AddFOp>(def)) 807 return addExp(kAddF, e0, e1); 808 if (isa<complex::AddOp>(def)) 809 return addExp(kAddC, e0, e1); 810 if (isa<arith::AddIOp>(def)) 811 return addExp(kAddI, e0, e1); 812 if (isa<arith::SubFOp>(def)) 813 return addExp(kSubF, e0, e1); 814 if (isa<arith::SubIOp>(def)) 815 return addExp(kSubI, e0, e1); 816 if (isa<arith::AndIOp>(def)) 817 return addExp(kAndI, e0, e1); 818 if (isa<arith::OrIOp>(def)) 819 return addExp(kOrI, e0, e1); 820 if (isa<arith::XOrIOp>(def)) 821 return addExp(kXorI, e0, e1); 822 if (isa<arith::ShRSIOp>(def) && isInvariant(e1)) 823 return addExp(kShrS, e0, e1); 824 if (isa<arith::ShRUIOp>(def) && isInvariant(e1)) 825 return addExp(kShrU, e0, e1); 826 if (isa<arith::ShLIOp>(def) && isInvariant(e1)) 827 return addExp(kShlI, e0, e1); 828 if (isa<sparse_tensor::BinaryOp>(def)) 829 return addExp(kBinary, e0, e1, Value(), def); 830 } 831 } 832 // Cannot build. 833 return None; 834 } 835 836 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region ®ion, 837 ValueRange vals) { 838 // Make a clone of overlap region. 839 Region tmpRegion; 840 BlockAndValueMapping mapper; 841 region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper); 842 Block &clonedBlock = tmpRegion.front(); 843 YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator()); 844 // Merge cloned block and return yield value. 845 Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0); 846 rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals); 847 Value val = clonedYield.result(); 848 rewriter.eraseOp(clonedYield); 849 rewriter.eraseOp(placeholder); 850 return val; 851 } 852 853 static Value buildUnaryPresent(RewriterBase &rewriter, Location loc, 854 Operation *op, Value v0) { 855 if (!v0) 856 // Empty input value must be propagated. 857 return Value(); 858 UnaryOp unop = cast<UnaryOp>(op); 859 Region &presentRegion = unop.presentRegion(); 860 if (presentRegion.empty()) 861 // Uninitialized Value() will be interpreted as missing data in the 862 // output. 863 return Value(); 864 return insertYieldOp(rewriter, loc, presentRegion, {v0}); 865 } 866 867 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc, 868 Operation *op, Value v0, Value v1) { 869 if (!v0 || !v1) 870 // Empty input values must be propagated. 871 return Value(); 872 BinaryOp binop = cast<BinaryOp>(op); 873 Region &overlapRegion = binop.overlapRegion(); 874 if (overlapRegion.empty()) 875 // Uninitialized Value() will be interpreted as missing data in the 876 // output. 877 return Value(); 878 return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1}); 879 } 880 881 Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e, 882 Value v0, Value v1) { 883 switch (tensorExps[e].kind) { 884 case kTensor: 885 case kInvariant: 886 case kIndex: 887 llvm_unreachable("unexpected non-op"); 888 // Unary ops. 889 case kAbsF: 890 return rewriter.create<math::AbsOp>(loc, v0); 891 case kCeilF: 892 return rewriter.create<math::CeilOp>(loc, v0); 893 case kFloorF: 894 return rewriter.create<math::FloorOp>(loc, v0); 895 case kSqrtF: 896 return rewriter.create<math::SqrtOp>(loc, v0); 897 case kExpm1F: 898 return rewriter.create<math::ExpM1Op>(loc, v0); 899 case kLog1pF: 900 return rewriter.create<math::Log1pOp>(loc, v0); 901 case kSinF: 902 return rewriter.create<math::SinOp>(loc, v0); 903 case kTanhF: 904 return rewriter.create<math::TanhOp>(loc, v0); 905 case kNegF: 906 return rewriter.create<arith::NegFOp>(loc, v0); 907 case kNegI: // no negi in std 908 return rewriter.create<arith::SubIOp>( 909 loc, 910 rewriter.create<arith::ConstantOp>(loc, v0.getType(), 911 rewriter.getZeroAttr(v0.getType())), 912 v0); 913 case kTruncF: 914 return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0); 915 case kExtF: 916 return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0); 917 case kCastFS: 918 return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0); 919 case kCastFU: 920 return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0); 921 case kCastSF: 922 return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0); 923 case kCastUF: 924 return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0); 925 case kCastS: 926 return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0); 927 case kCastU: 928 return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0); 929 case kCastIdx: 930 return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0); 931 case kTruncI: 932 return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0); 933 case kBitCast: 934 return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0); 935 // Binary ops. 936 case kMulF: 937 return rewriter.create<arith::MulFOp>(loc, v0, v1); 938 case kMulC: 939 return rewriter.create<complex::MulOp>(loc, v0, v1); 940 case kMulI: 941 return rewriter.create<arith::MulIOp>(loc, v0, v1); 942 case kDivF: 943 return rewriter.create<arith::DivFOp>(loc, v0, v1); 944 case kDivS: 945 return rewriter.create<arith::DivSIOp>(loc, v0, v1); 946 case kDivU: 947 return rewriter.create<arith::DivUIOp>(loc, v0, v1); 948 case kAddF: 949 return rewriter.create<arith::AddFOp>(loc, v0, v1); 950 case kAddC: 951 return rewriter.create<complex::AddOp>(loc, v0, v1); 952 case kAddI: 953 return rewriter.create<arith::AddIOp>(loc, v0, v1); 954 case kSubF: 955 return rewriter.create<arith::SubFOp>(loc, v0, v1); 956 case kSubI: 957 return rewriter.create<arith::SubIOp>(loc, v0, v1); 958 case kAndI: 959 return rewriter.create<arith::AndIOp>(loc, v0, v1); 960 case kOrI: 961 return rewriter.create<arith::OrIOp>(loc, v0, v1); 962 case kXorI: 963 return rewriter.create<arith::XOrIOp>(loc, v0, v1); 964 case kShrS: 965 return rewriter.create<arith::ShRSIOp>(loc, v0, v1); 966 case kShrU: 967 return rewriter.create<arith::ShRUIOp>(loc, v0, v1); 968 case kShlI: 969 return rewriter.create<arith::ShLIOp>(loc, v0, v1); 970 // Semiring ops with custom logic. 971 case kBinaryBranch: 972 return insertYieldOp(rewriter, loc, 973 *tensorExps[e].op->getBlock()->getParent(), {v0}); 974 case kUnary: 975 return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0); 976 case kBinary: 977 return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1); 978 } 979 llvm_unreachable("unexpected expression kind in build"); 980 } 981 982 } // namespace sparse_tensor 983 } // namespace mlir 984