1 //===- Merger.cpp - Implementation of iteration lattices ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 #include "mlir/Dialect/Math/IR/Math.h" 12 13 #include "mlir/IR/Operation.h" 14 #include "llvm/Support/Debug.h" 15 16 namespace mlir { 17 namespace sparse_tensor { 18 19 //===----------------------------------------------------------------------===// 20 // Constructors. 21 //===----------------------------------------------------------------------===// 22 23 TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v) 24 : kind(k), val(v) { 25 switch (kind) { 26 case kTensor: 27 assert(x != -1u && y == -1u && !v); 28 tensor = x; 29 break; 30 case kInvariant: 31 assert(x == -1u && y == -1u && v); 32 break; 33 case kIndex: 34 assert(x != -1u && y == -1u && !v); 35 index = x; 36 break; 37 case kAbsF: 38 case kCeilF: 39 case kFloorF: 40 case kNegF: 41 case kNegI: 42 assert(x != -1u && y == -1u && !v); 43 children.e0 = x; 44 children.e1 = y; 45 break; 46 case kTruncF: 47 case kExtF: 48 case kCastFS: 49 case kCastFU: 50 case kCastSF: 51 case kCastUF: 52 case kCastS: 53 case kCastU: 54 case kCastIdx: 55 case kTruncI: 56 case kBitCast: 57 assert(x != -1u && y == -1u && v); 58 children.e0 = x; 59 children.e1 = y; 60 break; 61 default: 62 assert(x != -1u && y != -1u && !v); 63 children.e0 = x; 64 children.e1 = y; 65 break; 66 } 67 } 68 69 LatPoint::LatPoint(unsigned n, unsigned e, unsigned b) 70 : bits(n, false), simple(), exp(e) { 71 bits.set(b); 72 } 73 74 LatPoint::LatPoint(const BitVector &b, unsigned e) 75 : bits(b), simple(), exp(e) {} 76 77 
//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

/// Appends a new tensor expression of kind `k` with operands `e0`/`e1` and
/// value `v` to `tensorExps`, and returns the index of the new expression.
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

/// Appends a new lattice point for tensor `t` at loop index `i` with
/// expression `e`, and returns the index of the new point. The condition has
/// numLoops * numTensors bits, of which only bit (numTensors * i + t) is set.
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

/// Appends a new, empty lattice set and returns its index. Note that this may
/// reallocate `latSets`, so references into it must not be held across calls.
unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

/// Appends the conjunction of lattice points `p0` and `p1` and returns its
/// index. The new condition is the union of both points' bits, and the new
/// expression applies `kind` to both points' expressions.
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

/// Builds a new set holding the conjunction of every point of `s0` with every
/// point of `s1` (in s0-major order), and returns the new set's index.
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

/// Builds a new set for the disjunction of `s0` and `s1`, and returns its
/// index. The set lists all pairwise conjunctions first, then every point of
/// `s0`, then every point of `s1`. For subtraction, the points of `s1` are
/// first mapped through negation so that the "only y" case computes -y.
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

/// Builds a new set by applying the unary operator `kind` (with value `v`,
/// which casts use to carry their result) to the expression of every point in
/// `s0`, and returns the new set's index. Each new point keeps the condition
/// bits of the point it was derived from.
unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) {
  // Only unary kinds are allowed; this relies on them being declared
  // contiguously from kAbsF to kBitCast in the Kind enum.
  assert(kAbsF <= kind && kind <= kBitCast);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

/// Returns a new set holding an optimized subset of the points of `s0`.
/// The first point of `s0` is always kept. Any later point is dropped when
/// its expression is just the output tensor (a straightforward copy), or
/// when it differs from an already-kept point only in non-sparse dimensions.
/// Afterwards, the simplified condition of every kept point is computed.
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      // Every kept non-first point must lie strictly below the first one.
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

/// Returns a simplified copy of the condition bits of point `p0` in set `s0`.
/// Sparse bits are always kept. Among the non-sparse bits, every bit after
/// the first is cleared. The first one is also cleared when `p0` is a
/// singleton (no other point of `s0` lies strictly below it) that has at
/// least one sparse bit.
BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

/// Returns true if the condition of point `i` is a strict superset of the
/// condition of point `j`: it has more bits set, and includes every bit of `j`.
bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

/// Returns true if the conditions of points `i` and `j` differ only in bits
/// that do not refer to sparse dimensions.
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

/// Returns true if any set bit in `bits` refers to a dimension of kind `d`.
bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

/// Returns true if expression `e` only involves the condition of tensor `t`,
/// judged structurally:
///   - a tensor leaf qualifies iff it is tensor `t`;
///   - unary ops, x/c divisions, and shifts by an invariant qualify iff their
///     (left) operand qualifies;
///   - mul/and qualify if both operands qualify, or if one qualifies and the
///     other is invariant;
///   - add qualifies iff both operands qualify;
///   - all other kinds (invariant, index, sub, or, xor) do not qualify.
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

/// Returns the printable symbol for expression kind `kind`. All casts share
/// the symbol "cast"; integer and floating-point variants share operators.
static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kMulF:
    return "*";
  case kMulI:
    return "*";
  case kDivF:
    return "/";
  case kDivS:
    return "/";
  case kDivU:
    return "/";
  case kAddF:
    return "+";
  case kAddI:
    return "+";
  case kSubF:
    return "-";
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  }
  llvm_unreachable("unexpected kind for symbol");
}

/// Prints expression `e` to the debug stream: leaves by name, unary
/// operators in prefix form, and binary operators fully parenthesized.
void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

/// Prints lattice point `p`: its condition bits, its simplified condition
/// bits, and its expression.
void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

/// Prints lattice set `s`: its size followed by each of its points.
void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

/// Prints each set bit of `bits` as i_<tensor>_<loop>_<dim>, where <dim> is
/// S (sparse), D (dense), T (kSingle) or U (undefined).
void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      // Decompose the bit position into its tensor and loop index.
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===// 430 // Builder methods. 431 //===----------------------------------------------------------------------===// 432 433 unsigned Merger::buildLattices(unsigned e, unsigned i) { 434 Kind kind = tensorExps[e].kind; 435 switch (kind) { 436 case kTensor: 437 case kInvariant: 438 case kIndex: { 439 // Either the index is really used in the tensor expression, or it is 440 // set to the undefined index in that dimension. An invariant expression, 441 // a proper index value, and a truly dynamic sparse output tensor are set 442 // to a synthetic tensor with undefined indices only to ensure the 443 // iteration space is not skipped as a result of their contents. 444 unsigned s = addSet(); 445 unsigned t = syntheticTensor; 446 if (kind == kTensor) { 447 t = tensorExps[e].tensor; 448 if (hasSparseOut && t == outTensor) 449 t = syntheticTensor; 450 } 451 latSets[s].push_back(addLat(t, i, e)); 452 return s; 453 } 454 case kAbsF: 455 case kCeilF: 456 case kFloorF: 457 case kNegF: 458 case kNegI: 459 case kTruncF: 460 case kExtF: 461 case kCastFS: 462 case kCastFU: 463 case kCastSF: 464 case kCastUF: 465 case kCastS: 466 case kCastU: 467 case kCastIdx: 468 case kTruncI: 469 case kBitCast: 470 // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the 471 // lattice set of the operand through the operator into a new set. 472 // 473 // -y|!y | y | 474 // --+---+---+ 475 // | 0 |-y | 476 return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), 477 tensorExps[e].val); 478 case kMulF: 479 case kMulI: 480 case kAndI: 481 // A multiplicative operation only needs to be performed 482 // for the conjunction of sparse iteration spaces. 
483 // 484 // x*y|!y | y | 485 // ---+---+---+ 486 // !x | 0 | 0 | 487 // x | 0 |x*y| 488 return takeConj(kind, // take binary conjunction 489 buildLattices(tensorExps[e].children.e0, i), 490 buildLattices(tensorExps[e].children.e1, i)); 491 case kDivF: 492 case kDivS: 493 case kDivU: 494 // A division is tricky, since 0/0, 0/c, c/0 all have 495 // specific outcomes for floating-point and integers. 496 // Thus, we need to traverse the full iteration space. 497 // 498 // x/y|!y | y | 499 // ---+---+---+ 500 // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero 501 // x |x/0|x/y| INT: x/0=exception for any x 502 // 503 // TODO: for now we "fixed" this by only accepting x/c cases 504 // during expression building, so that the conjunction 505 // rules applies (viz. x/c = x*(1/c) as far as lattice 506 // construction is concerned). 507 assert(!maybeZero(tensorExps[e].children.e1)); 508 return takeConj(kind, // take binary conjunction 509 buildLattices(tensorExps[e].children.e0, i), 510 buildLattices(tensorExps[e].children.e1, i)); 511 case kAddF: 512 case kAddI: 513 case kSubF: 514 case kSubI: 515 case kOrI: 516 case kXorI: 517 // An additive operation needs to be performed 518 // for the disjunction of sparse iteration spaces. 519 // 520 // x+y|!y | y | x-y|!y | y | 521 // ---+---+---+ ---+---+---+ 522 // !x | 0 | y | !x | 0 |-y | 523 // x | x |x+y| x | x |x-y| 524 return takeDisj(kind, // take binary disjunction 525 buildLattices(tensorExps[e].children.e0, i), 526 buildLattices(tensorExps[e].children.e1, i)); 527 case kShrS: 528 case kShrU: 529 case kShlI: 530 // A shift operation by an invariant amount (viz. tensor expressions 531 // can only occur at the left-hand-side of the operator) can be handled 532 // with the conjuction rule. 
533 assert(isInvariant(tensorExps[e].children.e1)); 534 return takeConj(kind, // take binary conjunction 535 buildLattices(tensorExps[e].children.e0, i), 536 buildLattices(tensorExps[e].children.e1, i)); 537 } 538 llvm_unreachable("unexpected expression kind"); 539 } 540 541 Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) { 542 Operation *yield = op.region().front().getTerminator(); 543 return buildTensorExp(op, yield->getOperand(0)); 544 } 545 546 /// Only returns false if we are certain this is a nonzero. 547 bool Merger::maybeZero(unsigned e) const { 548 if (tensorExps[e].kind == kInvariant) { 549 if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>()) 550 return c.value() == 0; 551 if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>()) 552 return c.value().isZero(); 553 } 554 return true; 555 } 556 557 bool Merger::isInvariant(unsigned e) const { 558 return tensorExps[e].kind == kInvariant; 559 } 560 561 Type Merger::inferType(unsigned e, Value src) { 562 // Obtain the destination type from the cast node. 563 Type dtp = tensorExps[e].val.getType(); 564 // Inspect source type. For vector types, apply the same 565 // vectorization to the destination type. 566 if (auto vtp = src.getType().dyn_cast<VectorType>()) 567 return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims()); 568 return dtp; 569 } 570 571 Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) { 572 if (auto arg = v.dyn_cast<BlockArgument>()) { 573 unsigned argN = arg.getArgNumber(); 574 // Any argument of the generic op that is not marked as a scalar 575 // argument is considered a tensor, indexed by the implicit loop 576 // bounds. This includes rank-0 tensor arguments. 
577 if (arg.getOwner()->getParentOp() == op) { 578 OpOperand *t = op.getInputAndOutputOperands()[argN]; 579 if (!op.isScalar(t)) 580 return addExp(kTensor, argN); 581 v = t->get(); // get scalar value 582 } 583 // Any other argument (marked as scalar argument for the generic op 584 // or belonging to an enveloping op) is considered invariant. 585 return addExp(kInvariant, v); 586 } 587 // Something defined outside is invariant. 588 Operation *def = v.getDefiningOp(); 589 if (def->getBlock() != &op.region().front()) 590 return addExp(kInvariant, v); 591 // Construct index operations. 592 if (def->getNumOperands() == 0) { 593 if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) 594 return addExp(kIndex, indexOp.dim()); 595 } 596 // Construct unary operations if subexpression can be built. 597 if (def->getNumOperands() == 1) { 598 auto x = buildTensorExp(op, def->getOperand(0)); 599 if (x.hasValue()) { 600 unsigned e = x.getValue(); 601 if (isa<math::AbsOp>(def)) 602 return addExp(kAbsF, e); 603 if (isa<math::CeilOp>(def)) 604 return addExp(kCeilF, e); 605 if (isa<math::FloorOp>(def)) 606 return addExp(kFloorF, e); 607 if (isa<arith::NegFOp>(def)) 608 return addExp(kNegF, e); // no negi in std 609 if (isa<arith::TruncFOp>(def)) 610 return addExp(kTruncF, e, v); 611 if (isa<arith::ExtFOp>(def)) 612 return addExp(kExtF, e, v); 613 if (isa<arith::FPToSIOp>(def)) 614 return addExp(kCastFS, e, v); 615 if (isa<arith::FPToUIOp>(def)) 616 return addExp(kCastFU, e, v); 617 if (isa<arith::SIToFPOp>(def)) 618 return addExp(kCastSF, e, v); 619 if (isa<arith::UIToFPOp>(def)) 620 return addExp(kCastUF, e, v); 621 if (isa<arith::ExtSIOp>(def)) 622 return addExp(kCastS, e, v); 623 if (isa<arith::ExtUIOp>(def)) 624 return addExp(kCastU, e, v); 625 if (isa<arith::IndexCastOp>(def)) 626 return addExp(kCastIdx, e, v); 627 if (isa<arith::TruncIOp>(def)) 628 return addExp(kTruncI, e, v); 629 if (isa<arith::BitcastOp>(def)) 630 return addExp(kBitCast, e, v); 631 } 632 } 633 // Construct 
binary operations if subexpressions can be built. 634 // See buildLattices() for an explanation of rejecting certain 635 // division and shift operations 636 if (def->getNumOperands() == 2) { 637 auto x = buildTensorExp(op, def->getOperand(0)); 638 auto y = buildTensorExp(op, def->getOperand(1)); 639 if (x.hasValue() && y.hasValue()) { 640 unsigned e0 = x.getValue(); 641 unsigned e1 = y.getValue(); 642 if (isa<arith::MulFOp>(def)) 643 return addExp(kMulF, e0, e1); 644 if (isa<arith::MulIOp>(def)) 645 return addExp(kMulI, e0, e1); 646 if (isa<arith::DivFOp>(def) && !maybeZero(e1)) 647 return addExp(kDivF, e0, e1); 648 if (isa<arith::DivSIOp>(def) && !maybeZero(e1)) 649 return addExp(kDivS, e0, e1); 650 if (isa<arith::DivUIOp>(def) && !maybeZero(e1)) 651 return addExp(kDivU, e0, e1); 652 if (isa<arith::AddFOp>(def)) 653 return addExp(kAddF, e0, e1); 654 if (isa<arith::AddIOp>(def)) 655 return addExp(kAddI, e0, e1); 656 if (isa<arith::SubFOp>(def)) 657 return addExp(kSubF, e0, e1); 658 if (isa<arith::SubIOp>(def)) 659 return addExp(kSubI, e0, e1); 660 if (isa<arith::AndIOp>(def)) 661 return addExp(kAndI, e0, e1); 662 if (isa<arith::OrIOp>(def)) 663 return addExp(kOrI, e0, e1); 664 if (isa<arith::XOrIOp>(def)) 665 return addExp(kXorI, e0, e1); 666 if (isa<arith::ShRSIOp>(def) && isInvariant(e1)) 667 return addExp(kShrS, e0, e1); 668 if (isa<arith::ShRUIOp>(def) && isInvariant(e1)) 669 return addExp(kShrU, e0, e1); 670 if (isa<arith::ShLIOp>(def) && isInvariant(e1)) 671 return addExp(kShlI, e0, e1); 672 } 673 } 674 // Cannot build. 675 return None; 676 } 677 678 Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e, 679 Value v0, Value v1) { 680 switch (tensorExps[e].kind) { 681 case kTensor: 682 case kInvariant: 683 case kIndex: 684 llvm_unreachable("unexpected non-op"); 685 // Unary ops. 
686 case kAbsF: 687 return rewriter.create<math::AbsOp>(loc, v0); 688 case kCeilF: 689 return rewriter.create<math::CeilOp>(loc, v0); 690 case kFloorF: 691 return rewriter.create<math::FloorOp>(loc, v0); 692 case kNegF: 693 return rewriter.create<arith::NegFOp>(loc, v0); 694 case kNegI: // no negi in std 695 return rewriter.create<arith::SubIOp>( 696 loc, 697 rewriter.create<arith::ConstantOp>(loc, v0.getType(), 698 rewriter.getZeroAttr(v0.getType())), 699 v0); 700 case kTruncF: 701 return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0); 702 case kExtF: 703 return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0); 704 case kCastFS: 705 return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0); 706 case kCastFU: 707 return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0); 708 case kCastSF: 709 return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0); 710 case kCastUF: 711 return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0); 712 case kCastS: 713 return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0); 714 case kCastU: 715 return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0); 716 case kCastIdx: 717 return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0); 718 case kTruncI: 719 return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0); 720 case kBitCast: 721 return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0); 722 // Binary ops. 
723 case kMulF: 724 return rewriter.create<arith::MulFOp>(loc, v0, v1); 725 case kMulI: 726 return rewriter.create<arith::MulIOp>(loc, v0, v1); 727 case kDivF: 728 return rewriter.create<arith::DivFOp>(loc, v0, v1); 729 case kDivS: 730 return rewriter.create<arith::DivSIOp>(loc, v0, v1); 731 case kDivU: 732 return rewriter.create<arith::DivUIOp>(loc, v0, v1); 733 case kAddF: 734 return rewriter.create<arith::AddFOp>(loc, v0, v1); 735 case kAddI: 736 return rewriter.create<arith::AddIOp>(loc, v0, v1); 737 case kSubF: 738 return rewriter.create<arith::SubFOp>(loc, v0, v1); 739 case kSubI: 740 return rewriter.create<arith::SubIOp>(loc, v0, v1); 741 case kAndI: 742 return rewriter.create<arith::AndIOp>(loc, v0, v1); 743 case kOrI: 744 return rewriter.create<arith::OrIOp>(loc, v0, v1); 745 case kXorI: 746 return rewriter.create<arith::XOrIOp>(loc, v0, v1); 747 case kShrS: 748 return rewriter.create<arith::ShRSIOp>(loc, v0, v1); 749 case kShrU: 750 return rewriter.create<arith::ShRUIOp>(loc, v0, v1); 751 case kShlI: 752 return rewriter.create<arith::ShLIOp>(loc, v0, v1); 753 } 754 llvm_unreachable("unexpected expression kind in build"); 755 } 756 757 } // namespace sparse_tensor 758 } // namespace mlir 759