//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
    : kind(k), val(v) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v);
    index = x;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    assert(x != -1u && y == -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

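/// Adds a tensor expression of the given kind (with optional children and
/// value) and returns its index in the expression array.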
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

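/// Adds an iteration lattice point for tensor t, loop index i, and
/// expression e, and returns its index in the lattice point array.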
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

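/// Adds a new, initially empty, set of lattice points and returns its index.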
unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

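/// Computes a single conjunction of two lattice points by taking the union
/// of their loop index bits and combining their expressions under the given
/// operation kind. Returns the index of the new lattice point.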
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

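/// Conjunctive merge of two lattice sets: the conjunction of the cross
/// product of all points in s0 and s1. Returns the index of the new set.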
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

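/// Disjunctive merge of two lattice sets: the conjunctive merge, followed by
/// all points of s0 and all points of s1. Returns the index of the new set.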
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

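/// Maps a unary operator over a lattice set: every point keeps its bits, but
/// its expression is wrapped in the given operation kind. Returns the index
/// of the new set.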
unsigned Merger::mapSet(Kind kind, unsigned s0, Value v) {
  assert(kAbsF <= kind && kind <= kBitCast);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

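/// Optimizes the iteration lattice points in the given set by removing
/// points whose conditions are already covered (differing only in dense
/// dimensions) and by computing the simplified conditions of the remaining
/// points.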
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

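/// Simplifies the conditions of the given lattice point: multiple non-sparse
/// conditions are reduced to a single one and, for a singleton point that
/// carries a sparse condition, all non-sparse conditions are dropped.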
BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice; no other point is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

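/// Returns true if the conditions of lattice point i form a strict superset
/// of the conditions of lattice point j.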
bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

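/// Returns true if the conditions of lattice points i and j differ in
/// non-sparse dimensions only.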
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

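/// Returns true if any set bit corresponds to a dimension of kind d.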
bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

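/// Returns true if expression e forms a single condition on tensor t, i.e.,
/// the expression can be evaluated by visiting the stored elements of t only.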
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kMulF:
    return "*";
  case kMulI:
    return "*";
  case kDivF:
    return "/";
  case kDivS:
    return "/";
  case kDivU:
    return "/";
  case kAddF:
    return "+";
  case kAddI:
    return "+";
  case kSubF:
    return "-";
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

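/// Recursively builds the lattice set for tensor (sub)expression e at loop
/// index i, applying the conjunction or disjunction rule that corresponds to
/// the kind of the operation. Returns the index of the resulting set.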
unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kMulF:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur on the left-hand side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}

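/// Recursively builds a tensor (sub)expression from the given value inside
/// the linalg.generic region, returning None when no expression can be
/// built for the value.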
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}

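/// Emits the operation that corresponds to tensor expression e at the
/// current insertion point, using the already constructed operand values
/// v0 and v1 where applicable.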
Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary ops.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir