//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
    : kind(k), val(v) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v);
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    assert(x != -1u && y == -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

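// Each lattice point records its iteration-space conditions in a bit vector
// of size numLoops * numTensors, where bit (numTensors * i + t) denotes
// tensor t at loop index i (see addLat() below and tensor()/index() for the
// decoding used by the dump methods).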
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

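// Constructs the conjunction of lattice points p0 and p1: the new point
// takes the union of both condition bit vectors and pairs the two
// subexpressions under the given operation kind.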
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

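// Conjunctive merge of two lattice sets: the cross product of all lattice
// points of s0 with all lattice points of s1.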
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

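// Disjunctive merge of two lattice sets: all pairwise conjunctions of s0
// and s1, followed by all points of s0 alone and all points of s1 alone.
// For subtraction, the points of s1 are first mapped through unary negation
// so that the binary "0-y" case is expressed as "-y".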
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

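// Maps a lattice set through a unary operator: every point keeps its
// condition bits, but its expression is wrapped in the given kind (the
// optional value carries, e.g., the destination type of a cast).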
unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) {
  assert(kAbsF <= kind && kind <= kBitCast);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

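// Optimizes a lattice set by removing lattice points whose sparse
// conditions are already covered by a previously accepted point (the two
// points differ in dense dimensions only), and by simplifying the
// conditions of the points that remain.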
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

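// Computes the simplified condition bits of lattice point p0 within set s0:
// when the point is a singleton that still carries a sparse condition, all
// non-sparse conditions are dropped; otherwise all but the first non-sparse
// condition are dropped (sparse conditions are always kept).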
llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice; no other point is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  llvm::BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

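// Returns true if lattice point i is strictly greater than lattice point j,
// i.e., the condition bits of j form a proper subset of the condition bits
// of i.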
bool Merger::latGT(unsigned i, unsigned j) const {
  const llvm::BitVector &bitsi = latPoints[i].bits;
  const llvm::BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

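// Returns true if lattice points i and j differ in non-sparse conditions
// only, i.e., the symmetric difference of their condition bits contains no
// sparse dimension.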
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  llvm::BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

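// Returns true if tensor expression e forms a single condition on tensor t;
// informally, e combines t only with invariants and zero-preserving
// operations, so e can be nonzero only where tensor t is nonzero.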
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kMulF:
    return "*";
  case kMulI:
    return "*";
  case kDivF:
    return "/";
  case kDivS:
    return "/";
  case kDivU:
    return "/";
  case kAddF:
    return "+";
  case kAddI:
    return "+";
  case kSubF:
    return "-";
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const llvm::BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

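// Recursively builds the iteration lattices for tensor expression e with
// respect to loop index i: leaves yield a singleton lattice set, unary
// operations map the operand set, and binary operations merge the operand
// sets either conjunctively or disjunctively.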
unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // and a truly dynamic sparse output tensor are set to a synthetic tensor
    // with undefined indices only to ensure the iteration space is not
    // skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
    if (hasSparseOut && t == outTensor)
      t = syntheticTensor;
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kMulF:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero constant.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp);
  return dtp;
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}

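// Rebuilds tensor expression e as actual operations at the given location,
// substituting the already generated operand values v0 (and v1 for binary
// operations).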
Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, v0, inferType(e, v0));
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, v0, inferType(e, v0));
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, v0, inferType(e, v0));
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, v0, inferType(e, v0));
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, v0, inferType(e, v0));
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, v0, inferType(e, v0));
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, v0, inferType(e, v0));
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, v0, inferType(e, v0));
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, v0, inferType(e, v0));
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, v0, inferType(e, v0));
  // Binary ops.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir