//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v && !o);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    break;
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kNegF:
  case kNegC:
  case kNegI:
  case kCIm:
  case kCRe:
    assert(x != -1u && y == -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinaryBranch:
    assert(x != -1u && y == -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinary:
    assert(x != -1u && y != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//
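
// Note on the bit encoding used throughout the lattice methods (exposition
// only): every lattice point carries a BitVector of numLoops * numTensors
// bits, and the bit for tensor t at loop index i lives at position
// numTensors * i + t (see addLat below and the tensor()/index() helpers).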

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
                        Operation *op) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
                              Operation *op) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = takeConj(kind, s0, s1, op);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubC)
    s1 = mapSet(kNegC, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
                           bool includeLeft, Kind ltrans, Operation *opleft,
                           bool includeRight, Kind rtrans, Operation *opright) {
  unsigned s = takeConj(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
  }
  return s;
}

unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
  assert(kAbsF <= kind && kind <= kUnary);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v, op);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

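// Example (for exposition only): for an addition of a sparse operand x and a
// dense operand y along a shared loop, the unoptimized disjunction contains
// the lattice points {x,y}, {x}, and {y}. Optimization below drops {x}, whose
// condition differs from {x,y} by dense bits only, so code generation emits a
// loop driven by the sparse x followed by a loop handling the remaining dense
// y entries.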
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

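// Example (for exposition only): x(t) * c, with c invariant, is a single
// condition for tensor t, since it only needs to be evaluated where t has
// stored entries, whereas x(t) + y(u) is not, since the sum can be nonzero
// where t has no stored entry.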
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivC:
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddC:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  case kAbsF:
  case kAbsC:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kSqrtF:
    return "sqrt";
  case kExpm1F:
    return "expm1";
  case kLog1pF:
  case kLog1pC:
    return "log1p";
  case kSinF:
  case kSinC:
    return "sin";
  case kTanhF:
    return "tanh";
  case kNegF:
  case kNegC:
  case kNegI:
    return "-";
  case kCIm:
    return "complex.im";
  case kCRe:
    return "complex.re";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kBinaryBranch:
    return "binary_branch";
  case kUnary:
    return "unary";
  case kMulF:
  case kMulC:
  case kMulI:
    return "*";
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    return "/";
  case kAddF:
  case kAddC:
  case kAddI:
    return "+";
  case kSubF:
  case kSubC:
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  case kBinary:
    return "binary";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kNegF:
  case kNegC:
  case kNegI:
  case kCIm:
  case kCRe:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

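// Example (for exposition only): building the lattices of
// kMulF(kTensor(a), kTensor(b)) at loop index i conjoins the operand sets
// into the single point {a_i, b_i}, whereas kAddF disjoins them into the
// three points {a_i, b_i}, {a_i}, and {b_i}, reflecting intersection versus
// union of the sparse iteration spaces.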
unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kCIm:
  case kCRe:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                  tensorExps[e].op);
  case kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
      Region &absentRegion = unop.absentRegion();

      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      Value absentVal = absentYield.result();
      unsigned rhs = addExp(kInvariant, absentVal);
      return takeDisj(kind, child0, buildLattices(rhs, i), unop);
    }
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
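    // In other words (exposition only), skipping the implicit zeros of a
    // sparse operand may drop NaN/Inf contributions that a dense evaluation
    // of the same expression would have produced.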
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
      Region &leftRegion = binop.leftRegion();
      Region &rightRegion = binop.rightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.left_identity() || !leftRegion.empty();
      bool includeRight = binop.right_identity() || !rightRegion.empty();
      return takeCombi(kBinary, child0, child1, binop, includeLeft,
                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                       rightYield);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
             arrayAttr[1].cast<FloatAttr>().getValue().isZero();
    }
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}

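// Example (for exposition only): a linalg.generic body such as
//   %0 = arith.mulf %a, %b : f64
//   linalg.yield %0 : f64
// is recognized below as the expression kMulF(kTensor(a), kTensor(b)), with
// each block argument that maps to a tensor operand becoming a kTensor leaf.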
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
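  // Note (exposition only): the cast-like operations below also record the
  // result value v, so that inferType() can later recover the destination
  // type when the operation is rebuilt.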
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<complex::AbsOp>(def))
        return addExp(kAbsC, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(kSqrtF, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(kExpm1F, e);
      if (isa<math::Log1pOp>(def))
        return addExp(kLog1pF, e);
      if (isa<complex::Log1pOp>(def))
        return addExp(kLog1pC, e);
      if (isa<math::SinOp>(def))
        return addExp(kSinF, e);
      if (isa<complex::SinOp>(def))
        return addExp(kSinC, e);
      if (isa<math::TanhOp>(def))
        return addExp(kTanhF, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<complex::NegOp>(def))
        return addExp(kNegC, e);
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (isa<sparse_tensor::UnaryOp>(def))
        return addExp(kUnary, e, Value(), def);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return addExp(kDivC, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<complex::SubOp>(def))
        return addExp(kSubC, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (isa<sparse_tensor::BinaryOp>(def))
        return addExp(kBinary, e0, e1, Value(), def);
    }
  }
  // Cannot build.
  return None;
}

static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of overlap region.
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.result();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.presentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.overlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

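// Note (exposition only): v0 and v1 are expected to carry the values already
// generated for the expression's children; the leaf kinds (tensor, invariant,
// index) are materialized by the caller, hence the unreachable below.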
Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kAbsC: {
    auto type = v0.getType().template cast<ComplexType>();
    auto eltType = type.getElementType().template cast<FloatType>();
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kCIm:
  case kCRe: {
    auto type = v0.getType().template cast<ComplexType>();
    auto eltType = type.getElementType().template cast<FloatType>();
    if (tensorExps[e].kind == kCIm)
      return rewriter.create<complex::ImOp>(loc, eltType, v0);

    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary ops.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  // Semiring ops with custom logic.
  case kBinaryBranch:
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
  case kUnary:
    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
  case kBinary:
    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir