//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
    : kind(k), val(v) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v);
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    assert(x != -1u && y == -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) {
  assert(kAbsF <= kind && kind <= kBitCast);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(latSets[s0].size() != 0);
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  llvm::BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const llvm::BitVector &bitsi = latPoints[i].bits;
  const llvm::BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  llvm::BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isConjunction(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    return isConjunction(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isConjunction(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isConjunction(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulI:
  case kAndI:
    return isConjunction(t, tensorExps[e].children.e0) ||
           isConjunction(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kMulF:
    return "*";
  case kMulI:
    return "*";
  case kDivF:
    return "/";
  case kDivS:
    return "/";
  case kDivU:
    return "/";
  case kAddF:
    return "+";
  case kAddI:
    return "+";
  case kSubF:
    return "-";
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const llvm::BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = addSet();
    unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kMulF:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp);
  return dtp;
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // TODO: see buildLattices() for an explanation of rejecting
  //       certain division and shift operations
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}

Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, v0, inferType(e, v0));
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, v0, inferType(e, v0));
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, v0, inferType(e, v0));
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, v0, inferType(e, v0));
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, v0, inferType(e, v0));
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, v0, inferType(e, v0));
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, v0, inferType(e, v0));
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, v0, inferType(e, v0));
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, v0, inferType(e, v0));
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, v0, inferType(e, v0));
  // Binary ops.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir