//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
    : kind(k), val(v) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v);
    index = x;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    assert(x != -1u && y == -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
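  // (Points coming only from s1 correspond to the case where the first
  // operand is absent, so a subtraction there reduces to 0-y; remapping
  // those points through the unary negation kind encodes exactly that.)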
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) {
  assert(kAbsF <= kind && kind <= kBitCast);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kMulF:
    return "*";
  case kMulI:
    return "*";
  case kDivF:
    return "/";
  case kDivS:
    return "/";
  case kDivU:
    return "/";
  case kAddF:
    return "+";
  case kAddI:
    return "+";
  case kSubF:
    return "-";
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kMulF:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
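    //       Consequently the assert below always holds: maybeZero() only
    //       returns false for a divisor that is a compile-time constant
    //       nonzero, so any other divisor never reaches this point.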
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp,
                           vtp.getNumScalableDims());
  return dtp;
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
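  // Note that the cast operations below also pass the result value v along,
  // so that inferType() can later recover the destination type of the cast.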
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}

Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
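  // (Only v0 carries a value for the unary cases below; v1 is ignored.)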
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary ops.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir