//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
    : kind(k), val(v) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v);
    index = x;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    assert(x != -1u && y == -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}
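
// Note on the bit encoding used above: each (tensor t, loop index i) pair is
// mapped to the single bit position b = numTensors * i + t in a bit vector
// of size numLoops * numTensors, and tensor(b) and index(b) recover the pair
// again. For example, with numTensors == 3 and numLoops == 2, tensor t == 1
// at loop index i == 1 occupies bit position 4.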

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}
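
// For example, for a subtraction x - y, takeDisj() above yields the lattice
// points for x - y (both operands present), followed by the points for x
// alone (since x - 0 == x), followed by the points for y alone mapped to -y
// (since 0 - y == -y).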

unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) {
  assert(kAbsF <= kind && kind <= kBitCast);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e., the last
  // point in a lattice, so that no other point is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}
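
// For example, a singleton lattice point that contains a sparse condition
// keeps only its sparse bits in the simplified set, while any other point
// keeps its sparse bits plus at most one non-sparse bit.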

bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}
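
// In other words, latGT(i, j) holds exactly when the condition bits of
// point i form a strict superset of the condition bits of point j.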

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}
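
// For example, given a tensor a, expressions such as a(i) * 2.0 and
// a(i) + a(i) are single conditions for a, whereas a(i) + b(i) is not,
// since the value of b(i) matters even where a(i) is zero.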

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kMulF:
    return "*";
  case kMulI:
    return "*";
  case kDivF:
    return "/";
  case kDivS:
    return "/";
  case kDivU:
    return "/";
  case kAddF:
    return "+";
  case kAddI:
    return "+";
  case kSubF:
    return "-";
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kMulF:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  }
  llvm_unreachable("unexpected expression kind");
}
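
// As a small worked example of the above: for an expression a(i) + b(i) over
// two sparse vectors, each operand yields a singleton lattice set {a} and
// {b}, and the disjunction rule produces the set { a /\ b, a, b } with the
// expressions a+b, a, and b, respectively, which optimizeSet() may then
// prune and simplify further.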

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}

Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary ops.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir