//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v && !o);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    assert(x != -1u && y == -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinaryBranch:
    assert(x != -1u && y == -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinary:
    assert(x != -1u && y != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
                        Operation *op) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

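// Constructs the conjunction of two lattice points: the new point takes the
// union of both iteration bit sets and combines the two expressions under
// the given operation kind.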
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
                              Operation *op) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

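// Conjunctive merge of two lattice sets: every lattice point in s0 is
// combined with every lattice point in s1 (cartesian product).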
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
  return s;
}

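// Disjunctive merge of two lattice sets: all pairwise conjunctions first,
// followed by all points in s0 and all points in s1 on their own (with the
// right-hand side of a subtraction mapped to a negation).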
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = takeConj(kind, s0, s1, op);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

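// Disjunctive merge with custom handling of the non-overlapping regions:
// after the pairwise conjunctions, the left and right operands are added on
// their own when requested, optionally mapped through the given
// transformation kinds and yield operations.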
unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
                           bool includeLeft, Kind ltrans, Operation *opleft,
                           bool includeRight, Kind rtrans, Operation *opright) {
  unsigned s = takeConj(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
  }
  return s;
}

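// Maps a lattice set through a unary operation: every point keeps its
// iteration bits but obtains a new expression that applies the given kind
// (with optional value v and operation op) to its original expression.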
unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
  assert(kAbsF <= kind && kind <= kUnary);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v, op);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

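// Optimizes the iteration lattice: drops lattice points that merely copy
// the output tensor or whose condition is already covered by an accepted
// point up to dense dimensions, and precomputes the simplified conditions
// of the remaining points.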
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

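// Simplifies the condition of lattice point p0 within set s0: all sparse
// bits are kept, at most one other bit survives, and none at all when the
// point is a singleton that already involves a sparse dimension.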
BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, so no other point is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

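// Returns true if lattice point i is strictly greater than lattice point j,
// i.e., the bit set of i is a proper superset of the bit set of j.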
bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

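// Returns true if the condition of tensor expression e reduces to a single
// condition on tensor t (e.g., accesses of t combined with invariants or
// other condition preserving operations).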
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

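// Returns a printable symbol or name for the given tensor expression kind,
// used by the debug dumps below.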
static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kBinaryBranch:
    return "binary_branch";
  case kUnary:
    return "unary";
  case kMulF:
    return "*";
  case kMulI:
    return "*";
  case kDivF:
    return "/";
  case kDivS:
    return "/";
  case kDivU:
    return "/";
  case kAddF:
    return "+";
  case kAddI:
    return "+";
  case kSubF:
    return "-";
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  case kBinary:
    return "binary";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

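// Recursively builds the iteration lattices for tensor expression e with
// respect to loop index i, returning the index of the resulting lattice set.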
unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                  tensorExps[e].op);
  case kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
      Region &absentRegion = unop.absentRegion();

      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      } else {
        // Use a disjunction with `unop` on the left and the absent value as an
        // invariant on the right.
        Block &absentBlock = absentRegion.front();
        YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
        Value absentVal = absentYield.result();
        unsigned rhs = addExp(kInvariant, absentVal);
        return takeDisj(kind, child0, buildLattices(rhs, i), unop);
      }
    }
  case kMulF:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur on the left-hand side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
      Region &leftRegion = binop.leftRegion();
      Region &rightRegion = binop.rightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.left_identity() || !leftRegion.empty();
      bool includeRight = binop.right_identity() || !rightRegion.empty();
      return takeCombi(kBinary, child0, child1, binop, includeLeft,
                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                       rightYield);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}

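// Recursively builds a tensor expression for the given value inside the
// linalg.generic operation, returning None when an operation is encountered
// that the merger cannot handle.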
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (isa<sparse_tensor::UnaryOp>(def))
        return addExp(kUnary, e, Value(), def);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (isa<sparse_tensor::BinaryOp>(def))
        return addExp(kBinary, e0, e1, Value(), def);
    }
  }
  // Cannot build.
  return None;
}

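// Clones the given region, merges the cloned block at the current insertion
// point (mapping its arguments to the given values), and returns the value
// produced by the cloned yield operation.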
static Value insertYieldOp(PatternRewriter &rewriter, Location loc,
                           Region &region, ValueRange vals) {
  // Make a clone of overlap region.
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.result();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

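// Builds the value of the "present" region of a custom unary operation for
// operand v0, or returns an empty Value when the operand or the region is
// missing (interpreted as missing data in the output).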
static Value buildUnaryPresent(PatternRewriter &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.presentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

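// Builds the value of the "overlap" region of a custom binary operation for
// operands v0 and v1, or returns an empty Value when an operand or the
// region is missing (interpreted as missing data in the output).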
static Value buildBinaryOverlap(PatternRewriter &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.overlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

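// Generates the operation that implements tensor expression e at the
// current insertion point, using v0 and v1 as the previously generated
// operand values.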
Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary ops.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  // Semiring ops with custom logic.
  case kBinaryBranch:
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
  case kUnary:
    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
  case kBinary:
    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir