1 //===- Merger.cpp - Implementation of iteration lattices ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" 11 12 #include "mlir/IR/Operation.h" 13 #include "llvm/Support/Debug.h" 14 15 namespace mlir { 16 namespace sparse_tensor { 17 18 //===----------------------------------------------------------------------===// 19 // Constructors. 20 //===----------------------------------------------------------------------===// 21 22 TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v) 23 : kind(k), val(v) { 24 switch (kind) { 25 case kTensor: 26 assert(x != -1u && y == -1u && !v); 27 tensor = x; 28 break; 29 case kInvariant: 30 assert(x == -1u && y == -1u && v); 31 break; 32 case kAbsF: 33 case kCeilF: 34 case kFloorF: 35 case kNegF: 36 case kNegI: 37 assert(x != -1u && y == -1u && !v); 38 children.e0 = x; 39 children.e1 = y; 40 break; 41 case kTruncF: 42 case kExtF: 43 case kCastFS: 44 case kCastFU: 45 case kCastSF: 46 case kCastUF: 47 case kCastS: 48 case kCastU: 49 case kTruncI: 50 case kBitCast: 51 assert(x != -1u && y == -1u && v); 52 children.e0 = x; 53 children.e1 = y; 54 break; 55 default: 56 assert(x != -1u && y != -1u && !v); 57 children.e0 = x; 58 children.e1 = y; 59 break; 60 } 61 } 62 63 LatPoint::LatPoint(unsigned n, unsigned e, unsigned b) 64 : bits(n, false), simple(), exp(e) { 65 bits.set(b); 66 } 67 68 LatPoint::LatPoint(const llvm::BitVector &b, unsigned e) 69 : bits(b), simple(), exp(e) {} 70 71 //===----------------------------------------------------------------------===// 72 // Lattice methods. 73 //===----------------------------------------------------------------------===// 74 75 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) { 76 unsigned e = tensorExps.size(); 77 tensorExps.push_back(TensorExp(k, e0, e1, v)); 78 return e; 79 } 80 81 unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) { 82 assert(t < numTensors && i < numLoops); 83 unsigned p = latPoints.size(); 84 latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t)); 85 return p; 86 } 87 88 unsigned Merger::addSet() { 89 unsigned s = latSets.size(); 90 latSets.emplace_back(SmallVector<unsigned, 16>()); 91 return s; 92 } 93 94 unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) { 95 unsigned p = latPoints.size(); 96 llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits); 97 nb |= latPoints[p1].bits; 98 unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp); 99 latPoints.push_back(LatPoint(nb, e)); 100 return p; 101 } 102 103 unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) { 104 unsigned s = addSet(); 105 for (unsigned p0 : latSets[s0]) 106 for (unsigned p1 : latSets[s1]) 107 latSets[s].push_back(conjLatPoint(kind, p0, p1)); 108 return s; 109 } 110 111 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) { 112 unsigned s = takeConj(kind, s0, s1); 113 // Followed by all in s0. 114 for (unsigned p : latSets[s0]) 115 latSets[s].push_back(p); 116 // Map binary 0-y to unary -y. 117 if (kind == kSubF) 118 s1 = mapSet(kNegF, s1); 119 else if (kind == kSubI) 120 s1 = mapSet(kNegI, s1); 121 // Followed by all in s1. 122 for (unsigned p : latSets[s1]) 123 latSets[s].push_back(p); 124 return s; 125 } 126 127 unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) { 128 assert(kAbsF <= kind && kind <= kBitCast); 129 unsigned s = addSet(); 130 for (unsigned p : latSets[s0]) { 131 unsigned e = addExp(kind, latPoints[p].exp, v); 132 latPoints.push_back(LatPoint(latPoints[p].bits, e)); 133 latSets[s].push_back(latPoints.size() - 1); 134 } 135 return s; 136 } 137 138 unsigned Merger::optimizeSet(unsigned s0) { 139 unsigned s = addSet(); 140 assert(latSets[s0].size() != 0); 141 unsigned p0 = latSets[s0][0]; 142 for (unsigned p1 : latSets[s0]) { 143 bool add = true; 144 if (p0 != p1) { 145 // Is this a straightforward copy? 146 unsigned e = latPoints[p1].exp; 147 if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor) 148 continue; 149 // Conjunction already covered? 150 for (unsigned p2 : latSets[s]) { 151 assert(!latGT(p1, p2)); // Lj => Li would be bad 152 if (onlyDenseDiff(p2, p1)) { 153 add = false; 154 break; 155 } 156 } 157 assert(!add || latGT(p0, p1)); 158 } 159 if (add) 160 latSets[s].push_back(p1); 161 } 162 for (unsigned p : latSets[s]) 163 latPoints[p].simple = simplifyCond(s, p); 164 return s; 165 } 166 167 llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) { 168 // First determine if this lattice point is a *singleton*, i.e., 169 // the last point in a lattice, no other is less than this one. 170 bool isSingleton = true; 171 for (unsigned p1 : latSets[s0]) { 172 if (p0 != p1 && latGT(p0, p1)) { 173 isSingleton = false; 174 break; 175 } 176 } 177 // Now apply the two basic rules. 178 llvm::BitVector simple = latPoints[p0].bits; 179 bool reset = isSingleton && hasAnyDimOf(simple, kSparse); 180 for (unsigned b = 0, be = simple.size(); b < be; b++) { 181 if (simple[b] && !isDim(b, kSparse)) { 182 if (reset) 183 simple.reset(b); 184 reset = true; 185 } 186 } 187 return simple; 188 } 189 190 bool Merger::latGT(unsigned i, unsigned j) const { 191 const llvm::BitVector &bitsi = latPoints[i].bits; 192 const llvm::BitVector &bitsj = latPoints[j].bits; 193 assert(bitsi.size() == bitsj.size()); 194 if (bitsi.count() > bitsj.count()) { 195 for (unsigned b = 0, be = bitsj.size(); b < be; b++) 196 if (bitsj[b] && !bitsi[b]) 197 return false; 198 return true; 199 } 200 return false; 201 } 202 203 bool Merger::onlyDenseDiff(unsigned i, unsigned j) { 204 llvm::BitVector tmp = latPoints[j].bits; 205 tmp ^= latPoints[i].bits; 206 return !hasAnyDimOf(tmp, kSparse); 207 } 208 209 bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const { 210 for (unsigned b = 0, be = bits.size(); b < be; b++) 211 if (bits[b] && isDim(b, d)) 212 return true; 213 return false; 214 } 215 216 bool Merger::isSingleCondition(unsigned t, unsigned e) const { 217 switch (tensorExps[e].kind) { 218 case kTensor: 219 return tensorExps[e].tensor == t; 220 case kAbsF: 221 case kCeilF: 222 case kFloorF: 223 case kNegF: 224 case kNegI: 225 case kTruncF: 226 case kExtF: 227 case kCastFS: 228 case kCastFU: 229 case kCastSF: 230 case kCastUF: 231 case kCastS: 232 case kCastU: 233 case kTruncI: 234 case kBitCast: 235 return isSingleCondition(t, tensorExps[e].children.e0); 236 case kDivF: // note: x / c only 237 case kDivS: 238 case kDivU: 239 assert(!maybeZero(tensorExps[e].children.e1)); 240 return isSingleCondition(t, tensorExps[e].children.e0); 241 case kShrS: // note: x >> inv only 242 case kShrU: 243 case kShlI: 244 assert(isInvariant(tensorExps[e].children.e1)); 245 return isSingleCondition(t, tensorExps[e].children.e0); 246 case kMulF: 247 case kMulI: 248 case kAndI: 249 if (isSingleCondition(t, tensorExps[e].children.e0)) 250 return isSingleCondition(t, tensorExps[e].children.e1) || 251 isInvariant(tensorExps[e].children.e1); 252 if (isSingleCondition(t, tensorExps[e].children.e1)) 253 return isInvariant(tensorExps[e].children.e0); 254 return false; 255 case kAddF: 256 case kAddI: 257 return isSingleCondition(t, tensorExps[e].children.e0) && 258 isSingleCondition(t, tensorExps[e].children.e1); 259 default: 260 return false; 261 } 262 } 263 264 #ifndef NDEBUG 265 266 //===----------------------------------------------------------------------===// 267 // Print methods (for debugging). 268 //===----------------------------------------------------------------------===// 269 270 static const char *kindToOpSymbol(Kind kind) { 271 switch (kind) { 272 case kTensor: 273 return "tensor"; 274 case kInvariant: 275 return "invariant"; 276 case kAbsF: 277 return "abs"; 278 case kCeilF: 279 return "ceil"; 280 case kFloorF: 281 return "floor"; 282 case kNegF: 283 return "-"; 284 case kNegI: 285 return "-"; 286 case kTruncF: 287 case kExtF: 288 case kCastFS: 289 case kCastFU: 290 case kCastSF: 291 case kCastUF: 292 case kCastS: 293 case kCastU: 294 case kTruncI: 295 case kBitCast: 296 return "cast"; 297 case kMulF: 298 return "*"; 299 case kMulI: 300 return "*"; 301 case kDivF: 302 return "/"; 303 case kDivS: 304 return "/"; 305 case kDivU: 306 return "/"; 307 case kAddF: 308 return "+"; 309 case kAddI: 310 return "+"; 311 case kSubF: 312 return "-"; 313 case kSubI: 314 return "-"; 315 case kAndI: 316 return "&"; 317 case kOrI: 318 return "|"; 319 case kXorI: 320 return "^"; 321 case kShrS: 322 return "a>>"; 323 case kShrU: 324 return ">>"; 325 case kShlI: 326 return "<<"; 327 } 328 llvm_unreachable("unexpected kind for symbol"); 329 } 330 331 void Merger::dumpExp(unsigned e) const { 332 switch (tensorExps[e].kind) { 333 case kTensor: 334 if (tensorExps[e].tensor == syntheticTensor) 335 llvm::dbgs() << "synthetic_"; 336 else if (tensorExps[e].tensor == outTensor) 337 llvm::dbgs() << "output_"; 338 llvm::dbgs() << "tensor_" << tensorExps[e].tensor; 339 break; 340 case kInvariant: 341 llvm::dbgs() << "invariant"; 342 break; 343 case kAbsF: 344 case kCeilF: 345 case kFloorF: 346 case kNegF: 347 case kNegI: 348 case kTruncF: 349 case kExtF: 350 case kCastFS: 351 case kCastFU: 352 case kCastSF: 353 case kCastUF: 354 case kCastS: 355 case kCastU: 356 case kTruncI: 357 case kBitCast: 358 llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " "; 359 dumpExp(tensorExps[e].children.e0); 360 break; 361 default: 362 llvm::dbgs() << "("; 363 dumpExp(tensorExps[e].children.e0); 364 llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " "; 365 dumpExp(tensorExps[e].children.e1); 366 llvm::dbgs() << ")"; 367 } 368 } 369 370 void Merger::dumpLat(unsigned p) const { 371 llvm::dbgs() << "lat("; 372 dumpBits(latPoints[p].bits); 373 llvm::dbgs() << " :"; 374 dumpBits(latPoints[p].simple); 375 llvm::dbgs() << " : "; 376 dumpExp(latPoints[p].exp); 377 llvm::dbgs() << " )\n"; 378 } 379 380 void Merger::dumpSet(unsigned s) const { 381 llvm::dbgs() << "{ #" << latSets[s].size() << "\n"; 382 for (unsigned p : latSets[s]) { 383 llvm::dbgs() << " "; 384 dumpLat(p); 385 } 386 llvm::dbgs() << "}\n"; 387 } 388 389 void Merger::dumpBits(const llvm::BitVector &bits) const { 390 for (unsigned b = 0, be = bits.size(); b < be; b++) { 391 if (bits[b]) { 392 unsigned t = tensor(b); 393 unsigned i = index(b); 394 llvm::dbgs() << " i_" << t << "_" << i << "_"; 395 switch (dims[t][i]) { 396 case kSparse: 397 llvm::dbgs() << "S"; 398 break; 399 case kDense: 400 llvm::dbgs() << "D"; 401 break; 402 case kSingle: 403 llvm::dbgs() << "T"; 404 break; 405 case kUndef: 406 llvm::dbgs() << "U"; 407 break; 408 } 409 } 410 } 411 } 412 413 #endif // NDEBUG 414 415 //===----------------------------------------------------------------------===// 416 // Builder methods. 417 //===----------------------------------------------------------------------===// 418 419 unsigned Merger::buildLattices(unsigned e, unsigned i) { 420 Kind kind = tensorExps[e].kind; 421 switch (kind) { 422 case kTensor: 423 case kInvariant: { 424 // Either the index is really used in the tensor expression, or it is 425 // set to the undefined index in that dimension. An invariant expression 426 // and a truly dynamic sparse output tensor are set to a synthetic tensor 427 // with undefined indices only to ensure the iteration space is not 428 // skipped as a result of their contents. 429 unsigned s = addSet(); 430 unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor; 431 if (hasSparseOut && t == outTensor) 432 t = syntheticTensor; 433 latSets[s].push_back(addLat(t, i, e)); 434 return s; 435 } 436 case kAbsF: 437 case kCeilF: 438 case kFloorF: 439 case kNegF: 440 case kNegI: 441 case kTruncF: 442 case kExtF: 443 case kCastFS: 444 case kCastFU: 445 case kCastSF: 446 case kCastUF: 447 case kCastS: 448 case kCastU: 449 case kTruncI: 450 case kBitCast: 451 // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the 452 // lattice set of the operand through the operator into a new set. 453 // 454 // -y|!y | y | 455 // --+---+---+ 456 // | 0 |-y | 457 return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), 458 tensorExps[e].val); 459 case kMulF: 460 case kMulI: 461 case kAndI: 462 // A multiplicative operation only needs to be performed 463 // for the conjunction of sparse iteration spaces. 464 // 465 // x*y|!y | y | 466 // ---+---+---+ 467 // !x | 0 | 0 | 468 // x | 0 |x*y| 469 return takeConj(kind, // take binary conjunction 470 buildLattices(tensorExps[e].children.e0, i), 471 buildLattices(tensorExps[e].children.e1, i)); 472 case kDivF: 473 case kDivS: 474 case kDivU: 475 // A division is tricky, since 0/0, 0/c, c/0 all have 476 // specific outcomes for floating-point and integers. 477 // Thus, we need to traverse the full iteration space. 478 // 479 // x/y|!y | y | 480 // ---+---+---+ 481 // !x |0/0|0/y| FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero 482 // x |x/0|x/y| INT: x/0=exception for any x 483 // 484 // TODO: for now we "fixed" this by only accepting x/c cases 485 // during expression building, so that the conjunction 486 // rules applies (viz. x/c = x*(1/c) as far as lattice 487 // construction is concerned). 488 assert(!maybeZero(tensorExps[e].children.e1)); 489 return takeConj(kind, // take binary conjunction 490 buildLattices(tensorExps[e].children.e0, i), 491 buildLattices(tensorExps[e].children.e1, i)); 492 case kAddF: 493 case kAddI: 494 case kSubF: 495 case kSubI: 496 case kOrI: 497 case kXorI: 498 // An additive operation needs to be performed 499 // for the disjunction of sparse iteration spaces. 500 // 501 // x+y|!y | y | x-y|!y | y | 502 // ---+---+---+ ---+---+---+ 503 // !x | 0 | y | !x | 0 |-y | 504 // x | x |x+y| x | x |x-y| 505 return takeDisj(kind, // take binary disjunction 506 buildLattices(tensorExps[e].children.e0, i), 507 buildLattices(tensorExps[e].children.e1, i)); 508 case kShrS: 509 case kShrU: 510 case kShlI: 511 // A shift operation by an invariant amount (viz. tensor expressions 512 // can only occur at the left-hand-side of the operator) can be handled 513 // with the conjuction rule. 514 assert(isInvariant(tensorExps[e].children.e1)); 515 return takeConj(kind, // take binary conjunction 516 buildLattices(tensorExps[e].children.e0, i), 517 buildLattices(tensorExps[e].children.e1, i)); 518 } 519 llvm_unreachable("unexpected expression kind"); 520 } 521 522 Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) { 523 Operation *yield = op.region().front().getTerminator(); 524 return buildTensorExp(op, yield->getOperand(0)); 525 } 526 527 /// Only returns false if we are certain this is a nonzero. 528 bool Merger::maybeZero(unsigned e) const { 529 if (tensorExps[e].kind == kInvariant) { 530 if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>()) 531 return c.value() == 0; 532 if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>()) 533 return c.value().isZero(); 534 } 535 return true; 536 } 537 538 bool Merger::isInvariant(unsigned e) const { 539 return tensorExps[e].kind == kInvariant; 540 } 541 542 Type Merger::inferType(unsigned e, Value src) { 543 // Obtain the destination type from the cast node. 544 Type dtp = tensorExps[e].val.getType(); 545 // Inspect source type. For vector types, apply the same 546 // vectorization to the destination type. 547 if (auto vtp = src.getType().dyn_cast<VectorType>()) 548 return VectorType::get(vtp.getNumElements(), dtp); 549 return dtp; 550 } 551 552 Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) { 553 if (auto arg = v.dyn_cast<BlockArgument>()) { 554 unsigned argN = arg.getArgNumber(); 555 // Any argument of the generic op that is not marked as a scalar 556 // argument is considered a tensor, indexed by the implicit loop 557 // bounds. This includes rank-0 tensor arguments. 558 if (arg.getOwner()->getParentOp() == op) { 559 OpOperand *t = op.getInputAndOutputOperands()[argN]; 560 if (!op.isScalar(t)) 561 return addExp(kTensor, argN); 562 v = t->get(); // get scalar value 563 } 564 // Any other argument (marked as scalar argument for the generic op 565 // or belonging to an enveloping op) is considered invariant. 566 return addExp(kInvariant, v); 567 } 568 // Something defined outside is invariant. 569 Operation *def = v.getDefiningOp(); 570 if (def->getBlock() != &op.region().front()) 571 return addExp(kInvariant, v); 572 // Construct unary operations if subexpression can be built. 573 if (def->getNumOperands() == 1) { 574 auto x = buildTensorExp(op, def->getOperand(0)); 575 if (x.hasValue()) { 576 unsigned e = x.getValue(); 577 if (isa<math::AbsOp>(def)) 578 return addExp(kAbsF, e); 579 if (isa<math::CeilOp>(def)) 580 return addExp(kCeilF, e); 581 if (isa<math::FloorOp>(def)) 582 return addExp(kFloorF, e); 583 if (isa<arith::NegFOp>(def)) 584 return addExp(kNegF, e); // no negi in std 585 if (isa<arith::TruncFOp>(def)) 586 return addExp(kTruncF, e, v); 587 if (isa<arith::ExtFOp>(def)) 588 return addExp(kExtF, e, v); 589 if (isa<arith::FPToSIOp>(def)) 590 return addExp(kCastFS, e, v); 591 if (isa<arith::FPToUIOp>(def)) 592 return addExp(kCastFU, e, v); 593 if (isa<arith::SIToFPOp>(def)) 594 return addExp(kCastSF, e, v); 595 if (isa<arith::UIToFPOp>(def)) 596 return addExp(kCastUF, e, v); 597 if (isa<arith::ExtSIOp>(def)) 598 return addExp(kCastS, e, v); 599 if (isa<arith::ExtUIOp>(def)) 600 return addExp(kCastU, e, v); 601 if (isa<arith::TruncIOp>(def)) 602 return addExp(kTruncI, e, v); 603 if (isa<arith::BitcastOp>(def)) 604 return addExp(kBitCast, e, v); 605 } 606 } 607 // Construct binary operations if subexpressions can be built. 608 // See buildLattices() for an explanation of rejecting certain 609 // division and shift operations 610 if (def->getNumOperands() == 2) { 611 auto x = buildTensorExp(op, def->getOperand(0)); 612 auto y = buildTensorExp(op, def->getOperand(1)); 613 if (x.hasValue() && y.hasValue()) { 614 unsigned e0 = x.getValue(); 615 unsigned e1 = y.getValue(); 616 if (isa<arith::MulFOp>(def)) 617 return addExp(kMulF, e0, e1); 618 if (isa<arith::MulIOp>(def)) 619 return addExp(kMulI, e0, e1); 620 if (isa<arith::DivFOp>(def) && !maybeZero(e1)) 621 return addExp(kDivF, e0, e1); 622 if (isa<arith::DivSIOp>(def) && !maybeZero(e1)) 623 return addExp(kDivS, e0, e1); 624 if (isa<arith::DivUIOp>(def) && !maybeZero(e1)) 625 return addExp(kDivU, e0, e1); 626 if (isa<arith::AddFOp>(def)) 627 return addExp(kAddF, e0, e1); 628 if (isa<arith::AddIOp>(def)) 629 return addExp(kAddI, e0, e1); 630 if (isa<arith::SubFOp>(def)) 631 return addExp(kSubF, e0, e1); 632 if (isa<arith::SubIOp>(def)) 633 return addExp(kSubI, e0, e1); 634 if (isa<arith::AndIOp>(def)) 635 return addExp(kAndI, e0, e1); 636 if (isa<arith::OrIOp>(def)) 637 return addExp(kOrI, e0, e1); 638 if (isa<arith::XOrIOp>(def)) 639 return addExp(kXorI, e0, e1); 640 if (isa<arith::ShRSIOp>(def) && isInvariant(e1)) 641 return addExp(kShrS, e0, e1); 642 if (isa<arith::ShRUIOp>(def) && isInvariant(e1)) 643 return addExp(kShrU, e0, e1); 644 if (isa<arith::ShLIOp>(def) && isInvariant(e1)) 645 return addExp(kShlI, e0, e1); 646 } 647 } 648 // Cannot build. 649 return None; 650 } 651 652 Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e, 653 Value v0, Value v1) { 654 switch (tensorExps[e].kind) { 655 case kTensor: 656 case kInvariant: 657 llvm_unreachable("unexpected non-op"); 658 // Unary ops. 659 case kAbsF: 660 return rewriter.create<math::AbsOp>(loc, v0); 661 case kCeilF: 662 return rewriter.create<math::CeilOp>(loc, v0); 663 case kFloorF: 664 return rewriter.create<math::FloorOp>(loc, v0); 665 case kNegF: 666 return rewriter.create<arith::NegFOp>(loc, v0); 667 case kNegI: // no negi in std 668 return rewriter.create<arith::SubIOp>( 669 loc, 670 rewriter.create<arith::ConstantOp>(loc, v0.getType(), 671 rewriter.getZeroAttr(v0.getType())), 672 v0); 673 case kTruncF: 674 return rewriter.create<arith::TruncFOp>(loc, v0, inferType(e, v0)); 675 case kExtF: 676 return rewriter.create<arith::ExtFOp>(loc, v0, inferType(e, v0)); 677 case kCastFS: 678 return rewriter.create<arith::FPToSIOp>(loc, v0, inferType(e, v0)); 679 case kCastFU: 680 return rewriter.create<arith::FPToUIOp>(loc, v0, inferType(e, v0)); 681 case kCastSF: 682 return rewriter.create<arith::SIToFPOp>(loc, v0, inferType(e, v0)); 683 case kCastUF: 684 return rewriter.create<arith::UIToFPOp>(loc, v0, inferType(e, v0)); 685 case kCastS: 686 return rewriter.create<arith::ExtSIOp>(loc, v0, inferType(e, v0)); 687 case kCastU: 688 return rewriter.create<arith::ExtUIOp>(loc, v0, inferType(e, v0)); 689 case kTruncI: 690 return rewriter.create<arith::TruncIOp>(loc, v0, inferType(e, v0)); 691 case kBitCast: 692 return rewriter.create<arith::BitcastOp>(loc, v0, inferType(e, v0)); 693 // Binary ops. 694 case kMulF: 695 return rewriter.create<arith::MulFOp>(loc, v0, v1); 696 case kMulI: 697 return rewriter.create<arith::MulIOp>(loc, v0, v1); 698 case kDivF: 699 return rewriter.create<arith::DivFOp>(loc, v0, v1); 700 case kDivS: 701 return rewriter.create<arith::DivSIOp>(loc, v0, v1); 702 case kDivU: 703 return rewriter.create<arith::DivUIOp>(loc, v0, v1); 704 case kAddF: 705 return rewriter.create<arith::AddFOp>(loc, v0, v1); 706 case kAddI: 707 return rewriter.create<arith::AddIOp>(loc, v0, v1); 708 case kSubF: 709 return rewriter.create<arith::SubFOp>(loc, v0, v1); 710 case kSubI: 711 return rewriter.create<arith::SubIOp>(loc, v0, v1); 712 case kAndI: 713 return rewriter.create<arith::AndIOp>(loc, v0, v1); 714 case kOrI: 715 return rewriter.create<arith::OrIOp>(loc, v0, v1); 716 case kXorI: 717 return rewriter.create<arith::XOrIOp>(loc, v0, v1); 718 case kShrS: 719 return rewriter.create<arith::ShRSIOp>(loc, v0, v1); 720 case kShrU: 721 return rewriter.create<arith::ShRUIOp>(loc, v0, v1); 722 case kShlI: 723 return rewriter.create<arith::ShLIOp>(loc, v0, v1); 724 } 725 llvm_unreachable("unexpected expression kind in build"); 726 } 727 728 } // namespace sparse_tensor 729 } // namespace mlir 730