//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v && !o);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    assert(x != -1u && y == -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinaryBranch:
    assert(x != -1u && y == -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinary:
    assert(x != -1u && y != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//
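
// Note on the encoding of lattice point bits: addLat() below encodes the
// (tensor t, loop index i) pair as bit b = numTensors * i + t, which is
// decoded again by tensor(b) and index(b) (see e.g. dumpBits()).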

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
                        Operation *op) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
                              Operation *op) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = takeConj(kind, s0, s1, op);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
                           bool includeLeft, Kind ltrans, Operation *opleft,
                           bool includeRight, Kind rtrans, Operation *opright) {
  unsigned s = takeConj(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
  }
  return s;
}

unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
  assert(kAbsF <= kind && kind <= kUnary);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v, op);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
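      // (A difference in dense dimensions alone does not introduce a new
      // loop condition, so such a lattice point can safely be dropped.)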
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//
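
// Returns a human-readable symbol for the given tensor expression kind, as
// used by the dump methods below.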

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kBinaryBranch:
    return "binary_branch";
  case kUnary:
    return "unary";
  case kMulF:
    return "*";
  case kMulI:
    return "*";
  case kDivF:
    return "/";
  case kDivS:
    return "/";
  case kDivU:
    return "/";
  case kAddF:
    return "+";
  case kAddI:
    return "+";
  case kSubF:
    return "-";
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  case kBinary:
    return "binary";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//
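
// Builds the iteration lattices for the tensor expression rooted at e and
// loop index i by recursively combining the lattice sets of the operands
// according to the tables in the comments below.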

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                  tensorExps[e].op);
  case kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
      Region &absentRegion = unop.absentRegion();

      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      } else {
        // Use a disjunction with `unop` on the left and the absent value as
        // an invariant on the right.
        Block &absentBlock = absentRegion.front();
        YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
        Value absentVal = absentYield.result();
        unsigned rhs = addExp(kInvariant, absentVal);
        return takeDisj(kind, child0, buildLattices(rhs, i), unop);
      }
    }
  case kMulF:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
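    //       The restriction is enforced during expression construction in
    //       buildTensorExp(), which only accepts a division when maybeZero()
    //       can rule out a zero divisor.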
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
      Region &leftRegion = binop.leftRegion();
      Region &rightRegion = binop.rightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.left_identity() || !leftRegion.empty();
      bool includeRight = binop.right_identity() || !rightRegion.empty();
      return takeCombi(kBinary, child0, child1, binop, includeLeft,
                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                       rightYield);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp,
                           vtp.getNumScalableDims());
  return dtp;
}
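
// Recursively builds a tensor expression for the given value inside the
// linalg.generic op, returning None when an operation is encountered that
// the merger cannot handle.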

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (isa<sparse_tensor::UnaryOp>(def))
        return addExp(kUnary, e, Value(), def);
    }
  }
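  // Note: for the cast operations above, the result value v is recorded so
  // that inferType() can later recover the destination type of the cast.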
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (isa<sparse_tensor::BinaryOp>(def))
        return addExp(kBinary, e0, e1, Value(), def);
    }
  }
  // Cannot build.
  return None;
}

static Value insertYieldOp(PatternRewriter &rewriter, Location loc,
                           Region &region, ValueRange vals) {
  // Make a clone of the region.
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.result();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(PatternRewriter &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.presentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(PatternRewriter &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.overlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}
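
// Generates the scalar statement for tensor expression e from the generated
// operand values v0 and v1; for the custom semiring operations, an empty
// operand value propagates and is interpreted as missing data in the output.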

Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary ops.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  // Semiring ops with custom logic.
  case kBinaryBranch:
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
  case kUnary:
    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
  case kBinary:
    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir