//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
    : kind(k), val(v) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v);
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    assert(x != -1u && y == -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

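// A lattice point is characterized by a bit vector recording which
// (tensor, loop-index) pairs take part in its iteration condition, an
// optional simplified copy of those bits, and the index of the tensor
// (sub)expression that is evaluated at that point.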
LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

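// The bit for tensor t and loop index i is stored at position
// numTensors * i + t, so that tensor() and index() can decode it again.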
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

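// Constructs the lattice point for the conjunction p0 /\ p1: the union of
// both bit sets, together with a new expression that combines the operands.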
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

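// Conjunctive merge of two lattice sets: the conjunction of the
// cartesian product of all lattice points.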
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

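// Disjunctive merge of two lattice sets: the conjunctive cross product,
// followed by all points of s0 and all points of s1 on their own (the
// latter mapped through unary negation for subtraction).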
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

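// Maps a unary operator (with optional operation value) over all lattice
// points of the given set, yielding a new set.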
unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) {
  assert(kAbsF <= kind && kind <= kBitCast);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

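// Optimizes the iteration lattice points of the given set: points that
// would generate a straightforward copy of the output tensor, or whose
// conditions differ from an already accepted point in dense dimensions
// only, are dropped, and the conditions of the remaining points are
// simplified.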
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(latSets[s0].size() != 0);
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

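// Simplifies the condition bits of lattice point p0 within set s0 using two
// basic rules: multiple non-sparse conditions are reduced to a single one,
// and for a singleton point that involves a sparse dimension, all non-sparse
// conditions are dropped altogether.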
llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice: no other point is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  llvm::BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

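// Li > Lj iff the bits of Lj form a strict subset of the bits of Li.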
bool Merger::latGT(unsigned i, unsigned j) const {
  const llvm::BitVector &bitsi = latPoints[i].bits;
  const llvm::BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

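// Returns true if Li and Lj only differ in dense dimensions.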
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  llvm::BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

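// Returns true if tensor t appears in expression e under conjunctive
// (intersection-like) operators only, so that e needs to be evaluated only
// at positions where t is iterated. For the output tensor this identifies a
// "simply dynamic" update whose nonzero structure does not change.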
bool Merger::isConjunction(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    return isConjunction(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isConjunction(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isConjunction(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulI:
  case kAndI:
    return isConjunction(t, tensorExps[e].children.e0) ||
           isConjunction(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kMulF:
    return "*";
  case kMulI:
    return "*";
  case kDivF:
    return "/";
  case kDivS:
    return "/";
  case kDivU:
    return "/";
  case kAddF:
    return "+";
  case kAddI:
    return "+";
  case kSubF:
    return "-";
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const llvm::BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

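// Builds the iteration lattices in a bottom-up traversal of the tensor
// (sub)expression rooted at e, for the given loop index i. For example, a
// disjunction such as a(i) + b(i) yields the lattice points {a/\b, a, b},
// whereas a conjunction a(i) * b(i) yields the single point {a/\b}.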
unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = addSet();
    unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kMulF:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    //
    // TODO: remove this zero "folding" in favor of external pass into linalg
    //
    if (isZero(tensorExps[e].children.e1))
      return buildLattices(tensorExps[e].children.e0, i);
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns true if we are certain this is a zero.
bool Merger::isZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
      return c.getValue() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantFloatOp>())
      return c.getValue().isZero();
  }
  return false;
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
      return c.getValue() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantFloatOp>())
      return c.getValue().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp);
  return dtp;
}

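// Builds the tensor (sub)expression tree bottom-up by walking the value v,
// which is defined inside the body of the linalg.generic op; returns None
// when a (sub)expression cannot be handled.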
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<AbsFOp>(def))
        return addExp(kAbsF, e);
      if (isa<CeilFOp>(def))
        return addExp(kCeilF, e);
      if (isa<FloorFOp>(def))
        return addExp(kFloorF, e);
      if (isa<NegFOp>(def))
        return addExp(kNegF, e); // TODO: no negi in std?
      if (isa<FPTruncOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<FPExtOp>(def))
        return addExp(kExtF, e, v);
      if (isa<FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<SignExtendIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<ZeroExtendIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<TruncateIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<BitcastOp>(def))
        return addExp(kBitCast, e, v);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // TODO: see buildLattices() for an explanation of rejecting
  //       certain division and shift operations
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<SignedDivIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<UnsignedDivIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<AndOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<OrOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<XOrOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<SignedShiftRightOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<UnsignedShiftRightOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<ShiftLeftOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}

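// Emits the scalar operation that implements tensor expression e at the
// current loop position, using the previously generated operand values
// v0 and v1.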
Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
  case kAbsF:
    return rewriter.create<AbsFOp>(loc, v0);
  case kCeilF:
    return rewriter.create<CeilFOp>(loc, v0);
  case kFloorF:
    return rewriter.create<FloorFOp>(loc, v0);
  case kNegF:
    return rewriter.create<NegFOp>(loc, v0);
  case kNegI:
    assert(v1); // no negi in std
    return rewriter.create<SubIOp>(loc, v0, v1);
  case kTruncF:
    return rewriter.create<FPTruncOp>(loc, v0, inferType(e, v0));
  case kExtF:
    return rewriter.create<FPExtOp>(loc, v0, inferType(e, v0));
  case kCastFS:
    return rewriter.create<FPToSIOp>(loc, v0, inferType(e, v0));
  case kCastFU:
    return rewriter.create<FPToUIOp>(loc, v0, inferType(e, v0));
  case kCastSF:
    return rewriter.create<SIToFPOp>(loc, v0, inferType(e, v0));
  case kCastUF:
    return rewriter.create<UIToFPOp>(loc, v0, inferType(e, v0));
  case kCastS:
    return rewriter.create<SignExtendIOp>(loc, v0, inferType(e, v0));
  case kCastU:
    return rewriter.create<ZeroExtendIOp>(loc, v0, inferType(e, v0));
  case kTruncI:
    return rewriter.create<TruncateIOp>(loc, v0, inferType(e, v0));
  case kBitCast:
    return rewriter.create<BitcastOp>(loc, v0, inferType(e, v0));
  // Binary ops.
  case kMulF:
    return rewriter.create<MulFOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<SignedDivIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<AddFOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<AndOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<OrOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<XOrOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<SignedShiftRightOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<UnsignedShiftRightOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<ShiftLeftOp>(loc, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir