//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v && !o);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
  case kCIm:
  case kCRe:
    assert(x != -1u && y == -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinaryBranch:
    assert(x != -1u && y == -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinary:
    assert(x != -1u && y != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

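// Note on the encoding: tensor expressions and lattice points are stored in
// flat arrays (tensorExps, latPoints) and refer to each other by index. For
// instance, an expression a(i) * b(i) is represented by two kTensor leaves
// followed by a kMulF node whose children.e0/e1 point back at those leaves,
// while a LatPoint pairs such an expression with the (tensor, loop-index)
// bits under which it is evaluated.
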
//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
                        Operation *op) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
                              Operation *op) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
  return s;
}

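// For example, for lattice sets s0 = { x } and s1 = { y } with one point
// each, takeConj(kind, s0, s1) above yields { x /\ y }, whereas takeDisj()
// below yields { x /\ y, x, y }: the conjunctions first, followed by the
// points where only the left and then only the right operand contributes.
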
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = takeConj(kind, s0, s1, op);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
                           bool includeLeft, Kind ltrans, Operation *opleft,
                           bool includeRight, Kind rtrans, Operation *opright) {
  unsigned s = takeConj(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
  }
  return s;
}

unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
  assert(kAbsF <= kind && kind <= kUnary);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v, op);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

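// For example, x(i) * c (with c invariant) and x(i) + x(i) are single
// conditions for tensor x, whereas x(i) + y(i) is not: the sum can be
// nonzero where only one of the two operands is nonzero.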
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddC:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kSqrtF:
    return "sqrt";
  case kExpm1F:
    return "expm1";
  case kLog1pF:
    return "log1p";
  case kSinF:
    return "sin";
  case kTanhF:
    return "tanh";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kCIm:
    return "complex.im";
  case kCRe:
    return "complex.re";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kBinaryBranch:
    return "binary_branch";
  case kUnary:
    return "unary";
  case kMulF:
  case kMulC:
  case kMulI:
    return "*";
  case kDivF:
  case kDivS:
  case kDivU:
    return "/";
  case kAddF:
  case kAddC:
  case kAddI:
    return "+";
  case kSubF:
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  case kBinary:
    return "binary";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

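// buildLattices() constructs the lattice set for expression e at loop index i
// bottom-up over the expression tree. For example, for a(i) * b(i) + c(i)
// with all three tensors sparse in dimension i:
//   a(i), b(i), c(i)    ->  { a }, { b }, { c }
//   a(i) * b(i)         ->  { a /\ b }                    (conjunction)
//   a(i) * b(i) + c(i)  ->  { a /\ b /\ c, a /\ b, c }    (disjunction)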
unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kCIm:
  case kCRe:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                  tensorExps[e].op);
  case kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
      Region &absentRegion = unop.absentRegion();

      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      Value absentVal = absentYield.result();
      unsigned rhs = addExp(kInvariant, absentVal);
      return takeDisj(kind, child0, buildLattices(rhs, i), unop);
    }
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|   INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |      y       |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
      Region &leftRegion = binop.leftRegion();
      Region &rightRegion = binop.rightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.left_identity() || !leftRegion.empty();
      bool includeRight = binop.right_identity() || !rightRegion.empty();
      return takeCombi(kBinary, child0, child1, binop, includeLeft,
                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                       rightYield);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

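// For example, for a linalg.generic whose body computes, schematically,
//   %0 = arith.mulf %a, %b : f64
//   %1 = arith.addf %0, %c : f64
//   linalg.yield %1 : f64
// the traversal below starts at the operand of the yield and recursively
// builds kMulF(kTensor(a), kTensor(b)) followed by kAddF(kMulF(...),
// kTensor(c)).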
Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(kSqrtF, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(kExpm1F, e);
      if (isa<math::Log1pOp>(def))
        return addExp(kLog1pF, e);
      if (isa<math::SinOp>(def))
        return addExp(kSinF, e);
      if (isa<math::TanhOp>(def))
        return addExp(kTanhF, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (isa<sparse_tensor::UnaryOp>(def))
        return addExp(kUnary, e, Value(), def);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (isa<sparse_tensor::BinaryOp>(def))
        return addExp(kBinary, e0, e1, Value(), def);
    }
  }
  // Cannot build.
  return None;
}

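// Clones the yielding region of a semiring op into the current insertion
// point, binds its block arguments to the given values, and returns the
// value yielded by the cloned region (the cloned yield itself is erased).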
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of overlap region.
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.result();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.presentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.overlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

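// Emits the IR for a single tensor expression node, given the already
// generated values v0 and v1 for its operands. For example, a kMulF node
// lowers to an arith.mulf of v0 and v1, and a kTruncI node lowers to an
// arith.trunci of v0 to the destination type recorded on the original cast.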
Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kCIm:
  case kCRe: {
    auto type = v0.getType().template cast<ComplexType>();
    auto eltType = type.getElementType().template cast<FloatType>();
    if (tensorExps[e].kind == kCIm)
      return rewriter.create<complex::ImOp>(loc, eltType, v0);

    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary ops.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  // Semiring ops with custom logic.
  case kBinaryBranch:
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
  case kUnary:
    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
  case kBinary:
    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir