1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Complex/IR/Complex.h"
12 #include "mlir/Dialect/Math/IR/Math.h"
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14 
15 #include "mlir/IR/Operation.h"
16 #include "llvm/Support/Debug.h"
17 
18 namespace mlir {
19 namespace sparse_tensor {
20 
21 //===----------------------------------------------------------------------===//
22 // Constructors.
23 //===----------------------------------------------------------------------===//
24 
25 TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
26     : kind(k), val(v), op(o) {
27   switch (kind) {
28   case kTensor:
29     assert(x != -1u && y == -1u && !v && !o);
30     tensor = x;
31     break;
32   case kInvariant:
33     assert(x == -1u && y == -1u && v && !o);
34     break;
35   case kIndex:
36     assert(x != -1u && y == -1u && !v && !o);
37     index = x;
38     break;
39   case kAbsF:
40   case kCeilF:
41   case kFloorF:
42   case kSqrtF:
43   case kExpm1F:
44   case kLog1pF:
45   case kSinF:
46   case kTanhF:
47   case kNegF:
48   case kNegI:
49     assert(x != -1u && y == -1u && !v && !o);
50     children.e0 = x;
51     children.e1 = y;
52     break;
53   case kTruncF:
54   case kExtF:
55   case kCastFS:
56   case kCastFU:
57   case kCastSF:
58   case kCastUF:
59   case kCastS:
60   case kCastU:
61   case kCastIdx:
62   case kTruncI:
63   case kBitCast:
64     assert(x != -1u && y == -1u && v && !o);
65     children.e0 = x;
66     children.e1 = y;
67     break;
68   case kBinaryBranch:
69     assert(x != -1u && y == -1u && !v && o);
70     children.e0 = x;
71     children.e1 = y;
72     break;
73   case kUnary:
74     // No assertion on y can be made, as the branching paths involve both
75     // a unary (mapSet) and binary (takeDisj) pathway.
76     assert(x != -1u && !v && o);
77     children.e0 = x;
78     children.e1 = y;
79     break;
80   case kBinary:
81     assert(x != -1u && y != -1u && !v && o);
82     children.e0 = x;
83     children.e1 = y;
84     break;
85   default:
86     assert(x != -1u && y != -1u && !v && !o);
87     children.e0 = x;
88     children.e1 = y;
89     break;
90   }
91 }
92 
93 LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
94     : bits(n, false), simple(), exp(e) {
95   bits.set(b);
96 }
97 
98 LatPoint::LatPoint(const BitVector &b, unsigned e)
99     : bits(b), simple(), exp(e) {}
100 
101 //===----------------------------------------------------------------------===//
102 // Lattice methods.
103 //===----------------------------------------------------------------------===//
104 
105 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
106                         Operation *op) {
107   unsigned e = tensorExps.size();
108   tensorExps.push_back(TensorExp(k, e0, e1, v, op));
109   return e;
110 }
111 
112 unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
113   assert(t < numTensors && i < numLoops);
114   unsigned p = latPoints.size();
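  // The bit vector spans numLoops * numTensors positions; the bit for
  // tensor t in loop i lives at position numTensors * i + t (e.g. with
  // numTensors = 3, loop i = 2 and tensor t = 1 map to bit 3 * 2 + 1 = 7).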
115   latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
116   return p;
117 }
118 
119 unsigned Merger::addSet() {
120   unsigned s = latSets.size();
121   latSets.emplace_back(SmallVector<unsigned, 16>());
122   return s;
123 }
124 
125 unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
126                               Operation *op) {
127   unsigned p = latPoints.size();
128   BitVector nb = BitVector(latPoints[p0].bits);
129   nb |= latPoints[p1].bits;
130   unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
131   latPoints.push_back(LatPoint(nb, e));
132   return p;
133 }
134 
135 unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
136   unsigned s = addSet();
137   for (unsigned p0 : latSets[s0])
138     for (unsigned p1 : latSets[s1])
139       latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
140   return s;
141 }
142 
143 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
144   unsigned s = takeConj(kind, s0, s1, op);
145   // Followed by all in s0.
146   for (unsigned p : latSets[s0])
147     latSets[s].push_back(p);
148   // Map binary 0-y to unary -y.
149   // TODO: move this if-else logic into buildLattices
150   if (kind == kSubF)
151     s1 = mapSet(kNegF, s1);
152   else if (kind == kSubI)
153     s1 = mapSet(kNegI, s1);
154   // Followed by all in s1.
155   for (unsigned p : latSets[s1])
156     latSets[s].push_back(p);
157   return s;
158 }
159 
160 unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
161                            bool includeLeft, Kind ltrans, Operation *opleft,
162                            bool includeRight, Kind rtrans, Operation *opright) {
163   unsigned s = takeConj(kind, s0, s1, orig);
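  // The conjunction of s0 and s1 (the overlap) is always included. When
  // requested, the left and right regions additionally contribute all
  // points of s0 and s1, mapped through the given transformation kind
  // whenever a corresponding operation is provided.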
164   // Left Region.
165   if (includeLeft) {
166     if (opleft)
167       s0 = mapSet(ltrans, s0, Value(), opleft);
168     for (unsigned p : latSets[s0])
169       latSets[s].push_back(p);
170   }
171   // Right Region.
172   if (includeRight) {
173     if (opright)
174       s1 = mapSet(rtrans, s1, Value(), opright);
175     for (unsigned p : latSets[s1])
176       latSets[s].push_back(p);
177   }
178   return s;
179 }
180 
181 unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
182   assert(kAbsF <= kind && kind <= kUnary);
183   unsigned s = addSet();
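  // Wrap each lattice point of s0 in the unary operator: same bits, but
  // with the expression replaced by kind(exp).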
184   for (unsigned p : latSets[s0]) {
185     unsigned e = addExp(kind, latPoints[p].exp, v, op);
186     latPoints.push_back(LatPoint(latPoints[p].bits, e));
187     latSets[s].push_back(latPoints.size() - 1);
188   }
189   return s;
190 }
191 
192 unsigned Merger::optimizeSet(unsigned s0) {
193   unsigned s = addSet();
194   assert(!latSets[s0].empty());
195   unsigned p0 = latSets[s0][0];
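  // Keep only those lattice points that contribute a new sparse condition:
  // drop points that merely copy the output tensor, as well as points whose
  // condition differs from an already kept point in dense dimensions only.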
196   for (unsigned p1 : latSets[s0]) {
197     bool add = true;
198     if (p0 != p1) {
199       // Is this a straightforward copy?
200       unsigned e = latPoints[p1].exp;
201       if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
202         continue;
203       // Conjunction already covered?
204       for (unsigned p2 : latSets[s]) {
205         assert(!latGT(p1, p2)); // Lj => Li would be bad
206         if (onlyDenseDiff(p2, p1)) {
207           add = false;
208           break;
209         }
210       }
211       assert(!add || latGT(p0, p1));
212     }
213     if (add)
214       latSets[s].push_back(p1);
215   }
216   for (unsigned p : latSets[s])
217     latPoints[p].simple = simplifyCond(s, p);
218   return s;
219 }
220 
221 BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
222   // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice: no other point is less than this one.
224   bool isSingleton = true;
225   for (unsigned p1 : latSets[s0]) {
226     if (p0 != p1 && latGT(p0, p1)) {
227       isSingleton = false;
228       break;
229     }
230   }
231   // Now apply the two basic rules.
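  // (1) A singleton lattice point with at least one sparse dimension keeps
  //     only its sparse conditions (all non-sparse bits are cleared).
  // (2) Otherwise, only the first non-sparse condition is kept; subsequent
  //     non-sparse bits are cleared.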
232   BitVector simple = latPoints[p0].bits;
233   bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
234   for (unsigned b = 0, be = simple.size(); b < be; b++) {
235     if (simple[b] && !isDim(b, kSparse)) {
236       if (reset)
237         simple.reset(b);
238       reset = true;
239     }
240   }
241   return simple;
242 }
243 
244 bool Merger::latGT(unsigned i, unsigned j) const {
245   const BitVector &bitsi = latPoints[i].bits;
246   const BitVector &bitsj = latPoints[j].bits;
247   assert(bitsi.size() == bitsj.size());
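  // Lattice point i is strictly greater than j iff the bits of j form a
  // proper subset of the bits of i.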
248   if (bitsi.count() > bitsj.count()) {
249     for (unsigned b = 0, be = bitsj.size(); b < be; b++)
250       if (bitsj[b] && !bitsi[b])
251         return false;
252     return true;
253   }
254   return false;
255 }
256 
257 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
258   BitVector tmp = latPoints[j].bits;
259   tmp ^= latPoints[i].bits;
260   return !hasAnyDimOf(tmp, kSparse);
261 }
262 
263 bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
264   for (unsigned b = 0, be = bits.size(); b < be; b++)
265     if (bits[b] && isDim(b, d))
266       return true;
267   return false;
268 }
269 
270 bool Merger::isSingleCondition(unsigned t, unsigned e) const {
271   switch (tensorExps[e].kind) {
272   case kTensor:
273     return tensorExps[e].tensor == t;
274   case kAbsF:
275   case kCeilF:
276   case kFloorF:
277   case kSqrtF:
278   case kExpm1F:
279   case kLog1pF:
280   case kSinF:
281   case kTanhF:
282   case kNegF:
283   case kNegI:
284   case kTruncF:
285   case kExtF:
286   case kCastFS:
287   case kCastFU:
288   case kCastSF:
289   case kCastUF:
290   case kCastS:
291   case kCastU:
292   case kCastIdx:
293   case kTruncI:
294   case kBitCast:
295     return isSingleCondition(t, tensorExps[e].children.e0);
296   case kDivF: // note: x / c only
297   case kDivS:
298   case kDivU:
299     assert(!maybeZero(tensorExps[e].children.e1));
300     return isSingleCondition(t, tensorExps[e].children.e0);
301   case kShrS: // note: x >> inv only
302   case kShrU:
303   case kShlI:
304     assert(isInvariant(tensorExps[e].children.e1));
305     return isSingleCondition(t, tensorExps[e].children.e0);
306   case kMulF:
307   case kMulC:
308   case kMulI:
309   case kAndI:
310     if (isSingleCondition(t, tensorExps[e].children.e0))
311       return isSingleCondition(t, tensorExps[e].children.e1) ||
312              isInvariant(tensorExps[e].children.e1);
313     if (isSingleCondition(t, tensorExps[e].children.e1))
314       return isInvariant(tensorExps[e].children.e0);
315     return false;
316   case kAddF:
317   case kAddC:
318   case kAddI:
319     return isSingleCondition(t, tensorExps[e].children.e0) &&
320            isSingleCondition(t, tensorExps[e].children.e1);
321   default:
322     return false;
323   }
324 }
325 
326 #ifndef NDEBUG
327 
328 //===----------------------------------------------------------------------===//
329 // Print methods (for debugging).
330 //===----------------------------------------------------------------------===//
331 
332 static const char *kindToOpSymbol(Kind kind) {
333   switch (kind) {
334   case kTensor:
335     return "tensor";
336   case kInvariant:
337     return "invariant";
338   case kIndex:
339     return "index";
340   case kAbsF:
341     return "abs";
342   case kCeilF:
343     return "ceil";
344   case kFloorF:
345     return "floor";
346   case kSqrtF:
347     return "sqrt";
348   case kExpm1F:
349     return "expm1";
350   case kLog1pF:
351     return "log1p";
352   case kSinF:
353     return "sin";
354   case kTanhF:
355     return "tanh";
356   case kNegF:
357     return "-";
358   case kNegI:
359     return "-";
360   case kTruncF:
361   case kExtF:
362   case kCastFS:
363   case kCastFU:
364   case kCastSF:
365   case kCastUF:
366   case kCastS:
367   case kCastU:
368   case kCastIdx:
369   case kTruncI:
370   case kBitCast:
371     return "cast";
372   case kBinaryBranch:
373     return "binary_branch";
374   case kUnary:
375     return "unary";
376   case kMulF:
377   case kMulC:
378   case kMulI:
379     return "*";
380   case kDivF:
381   case kDivS:
382   case kDivU:
383     return "/";
384   case kAddF:
385   case kAddC:
386   case kAddI:
387     return "+";
388   case kSubF:
389   case kSubI:
390     return "-";
391   case kAndI:
392     return "&";
393   case kOrI:
394     return "|";
395   case kXorI:
396     return "^";
397   case kShrS:
398     return "a>>";
399   case kShrU:
400     return ">>";
401   case kShlI:
402     return "<<";
403   case kBinary:
404     return "binary";
405   }
406   llvm_unreachable("unexpected kind for symbol");
407 }
408 
409 void Merger::dumpExp(unsigned e) const {
410   switch (tensorExps[e].kind) {
411   case kTensor:
412     if (tensorExps[e].tensor == syntheticTensor)
413       llvm::dbgs() << "synthetic_";
414     else if (tensorExps[e].tensor == outTensor)
415       llvm::dbgs() << "output_";
416     llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
417     break;
418   case kInvariant:
419     llvm::dbgs() << "invariant";
420     break;
421   case kIndex:
422     llvm::dbgs() << "index_" << tensorExps[e].index;
423     break;
424   case kAbsF:
425   case kCeilF:
426   case kFloorF:
427   case kSqrtF:
428   case kExpm1F:
429   case kLog1pF:
430   case kSinF:
431   case kTanhF:
432   case kNegF:
433   case kNegI:
434   case kTruncF:
435   case kExtF:
436   case kCastFS:
437   case kCastFU:
438   case kCastSF:
439   case kCastUF:
440   case kCastS:
441   case kCastU:
442   case kCastIdx:
443   case kTruncI:
444   case kBitCast:
445     llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
446     dumpExp(tensorExps[e].children.e0);
447     break;
448   default:
449     llvm::dbgs() << "(";
450     dumpExp(tensorExps[e].children.e0);
451     llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
452     dumpExp(tensorExps[e].children.e1);
453     llvm::dbgs() << ")";
454   }
455 }
456 
457 void Merger::dumpLat(unsigned p) const {
458   llvm::dbgs() << "lat(";
459   dumpBits(latPoints[p].bits);
460   llvm::dbgs() << " :";
461   dumpBits(latPoints[p].simple);
462   llvm::dbgs() << " : ";
463   dumpExp(latPoints[p].exp);
464   llvm::dbgs() << " )\n";
465 }
466 
467 void Merger::dumpSet(unsigned s) const {
468   llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
469   for (unsigned p : latSets[s]) {
470     llvm::dbgs() << "  ";
471     dumpLat(p);
472   }
473   llvm::dbgs() << "}\n";
474 }
475 
476 void Merger::dumpBits(const BitVector &bits) const {
477   for (unsigned b = 0, be = bits.size(); b < be; b++) {
478     if (bits[b]) {
479       unsigned t = tensor(b);
480       unsigned i = index(b);
481       llvm::dbgs() << " i_" << t << "_" << i << "_";
482       switch (dims[t][i]) {
483       case kSparse:
484         llvm::dbgs() << "S";
485         break;
486       case kDense:
487         llvm::dbgs() << "D";
488         break;
489       case kSingle:
490         llvm::dbgs() << "T";
491         break;
492       case kUndef:
493         llvm::dbgs() << "U";
494         break;
495       }
496     }
497   }
498 }
499 
500 #endif // NDEBUG
501 
502 //===----------------------------------------------------------------------===//
503 // Builder methods.
504 //===----------------------------------------------------------------------===//
505 
506 unsigned Merger::buildLattices(unsigned e, unsigned i) {
507   Kind kind = tensorExps[e].kind;
508   switch (kind) {
509   case kTensor:
510   case kInvariant:
511   case kIndex: {
512     // Either the index is really used in the tensor expression, or it is
513     // set to the undefined index in that dimension. An invariant expression,
514     // a proper index value, and a truly dynamic sparse output tensor are set
515     // to a synthetic tensor with undefined indices only to ensure the
516     // iteration space is not skipped as a result of their contents.
517     unsigned s = addSet();
518     unsigned t = syntheticTensor;
519     if (kind == kTensor) {
520       t = tensorExps[e].tensor;
521       if (hasSparseOut && t == outTensor)
522         t = syntheticTensor;
523     }
524     latSets[s].push_back(addLat(t, i, e));
525     return s;
526   }
527   case kAbsF:
528   case kCeilF:
529   case kFloorF:
530   case kSqrtF:
531   case kExpm1F:
532   case kLog1pF:
533   case kSinF:
534   case kTanhF:
535   case kNegF:
536   case kNegI:
537   case kTruncF:
538   case kExtF:
539   case kCastFS:
540   case kCastFU:
541   case kCastSF:
542   case kCastUF:
543   case kCastS:
544   case kCastU:
545   case kCastIdx:
546   case kTruncI:
547   case kBitCast:
548     // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
549     // lattice set of the operand through the operator into a new set.
550     //
551     //  -y|!y | y |
552     //  --+---+---+
553     //    | 0 |-y |
554     return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
555                   tensorExps[e].val);
556   case kBinaryBranch:
557     // The left or right half of a binary operation which has already
558     // been split into separate operations for each region.
559     return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
560                   tensorExps[e].op);
561   case kUnary:
562     // A custom unary operation.
563     //
564     //  op y|    !y    |     y      |
565     //  ----+----------+------------+
566     //      | absent() | present(y) |
567     {
568       unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
569       UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
570       Region &absentRegion = unop.absentRegion();
571 
572       if (absentRegion.empty()) {
573         // Simple mapping over existing values.
574         return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
577       Block &absentBlock = absentRegion.front();
578       YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
579       Value absentVal = absentYield.result();
580       unsigned rhs = addExp(kInvariant, absentVal);
581       return takeDisj(kind, child0, buildLattices(rhs, i), unop);
582     }
583   case kMulF:
584   case kMulC:
585   case kMulI:
586   case kAndI:
587     // A multiplicative operation only needs to be performed
588     // for the conjunction of sparse iteration spaces.
589     //
590     //  x*y|!y | y |
591     //  ---+---+---+
592     //  !x | 0 | 0 |
593     //   x | 0 |x*y|
594     //
    // Note that even here 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
596     return takeConj(kind, // take binary conjunction
597                     buildLattices(tensorExps[e].children.e0, i),
598                     buildLattices(tensorExps[e].children.e1, i));
599   case kDivF:
600   case kDivS:
601   case kDivU:
602     // A division is tricky, since 0/0, 0/c, c/0 all have
603     // specific outcomes for floating-point and integers.
604     // Thus, we need to traverse the full iteration space.
605     //
606     //  x/y|!y | y |
607     //  ---+---+---+
608     //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
609     //   x |x/0|x/y|  INT: x/0=exception for any x
610     //
611     // TODO: for now we "fixed" this by only accepting x/c cases
612     //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
614     //       construction is concerned).
615     assert(!maybeZero(tensorExps[e].children.e1));
616     return takeConj(kind, // take binary conjunction
617                     buildLattices(tensorExps[e].children.e0, i),
618                     buildLattices(tensorExps[e].children.e1, i));
619   case kAddF:
620   case kAddC:
621   case kAddI:
622   case kSubF:
623   case kSubI:
624   case kOrI:
625   case kXorI:
626     // An additive operation needs to be performed
627     // for the disjunction of sparse iteration spaces.
628     //
629     //  x+y|!y | y |    x-y|!y | y |
630     //  ---+---+---+    ---+---+---+
631     //  !x | 0 | y |    !x | 0 |-y |
632     //   x | x |x+y|     x | x |x-y|
633     return takeDisj(kind, // take binary disjunction
634                     buildLattices(tensorExps[e].children.e0, i),
635                     buildLattices(tensorExps[e].children.e1, i));
636   case kShrS:
637   case kShrU:
638   case kShlI:
639     // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur on the left-hand side of the operator) can be handled
    // with the conjunction rule.
642     assert(isInvariant(tensorExps[e].children.e1));
643     return takeConj(kind, // take binary conjunction
644                     buildLattices(tensorExps[e].children.e0, i),
645                     buildLattices(tensorExps[e].children.e1, i));
646   case kBinary:
647     // A custom binary operation.
648     //
649     //  x op y|   !y    |       y      |
650     //  ------+---------+--------------+
651     //    !x  |  empty  |   right(y)   |
652     //     x  | left(x) | overlap(x,y) |
653     {
654       unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
655       unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
656       BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
657       Region &leftRegion = binop.leftRegion();
658       Region &rightRegion = binop.rightRegion();
659       // Left Region.
660       Operation *leftYield = nullptr;
661       if (!leftRegion.empty()) {
662         Block &leftBlock = leftRegion.front();
663         leftYield = leftBlock.getTerminator();
664       }
665       // Right Region.
666       Operation *rightYield = nullptr;
667       if (!rightRegion.empty()) {
668         Block &rightBlock = rightRegion.front();
669         rightYield = rightBlock.getTerminator();
670       }
671       bool includeLeft = binop.left_identity() || !leftRegion.empty();
672       bool includeRight = binop.right_identity() || !rightRegion.empty();
673       return takeCombi(kBinary, child0, child1, binop, includeLeft,
674                        kBinaryBranch, leftYield, includeRight, kBinaryBranch,
675                        rightYield);
676     }
677   }
678   llvm_unreachable("unexpected expression kind");
679 }
680 
681 Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
682   Operation *yield = op.region().front().getTerminator();
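  // The tensor expression is rooted at the first value yielded from the
  // op's region.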
683   return buildTensorExp(op, yield->getOperand(0));
684 }
685 
/// Returns false only if we are certain this expression is a nonzero constant.
687 bool Merger::maybeZero(unsigned e) const {
688   if (tensorExps[e].kind == kInvariant) {
689     if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
690       return c.value() == 0;
691     if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
692       return c.value().isZero();
693   }
694   return true;
695 }
696 
697 bool Merger::isInvariant(unsigned e) const {
698   return tensorExps[e].kind == kInvariant;
699 }
700 
701 Type Merger::inferType(unsigned e, Value src) {
702   // Obtain the destination type from the cast node.
703   Type dtp = tensorExps[e].val.getType();
704   // Inspect source type. For vector types, apply the same
705   // vectorization to the destination type.
706   if (auto vtp = src.getType().dyn_cast<VectorType>())
707     return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
708   return dtp;
709 }
710 
711 Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
712   if (auto arg = v.dyn_cast<BlockArgument>()) {
713     unsigned argN = arg.getArgNumber();
714     // Any argument of the generic op that is not marked as a scalar
715     // argument is considered a tensor, indexed by the implicit loop
716     // bounds. This includes rank-0 tensor arguments.
717     if (arg.getOwner()->getParentOp() == op) {
718       OpOperand *t = op.getInputAndOutputOperands()[argN];
719       if (!op.isScalar(t))
720         return addExp(kTensor, argN);
721       v = t->get(); // get scalar value
722     }
723     // Any other argument (marked as scalar argument for the generic op
724     // or belonging to an enveloping op) is considered invariant.
725     return addExp(kInvariant, v);
726   }
727   // Something defined outside is invariant.
728   Operation *def = v.getDefiningOp();
729   if (def->getBlock() != &op.region().front())
730     return addExp(kInvariant, v);
731   // Construct index operations.
732   if (def->getNumOperands() == 0) {
733     if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
734       return addExp(kIndex, indexOp.dim());
735   }
736   // Construct unary operations if subexpression can be built.
737   if (def->getNumOperands() == 1) {
738     auto x = buildTensorExp(op, def->getOperand(0));
739     if (x.hasValue()) {
740       unsigned e = x.getValue();
741       if (isa<math::AbsOp>(def))
742         return addExp(kAbsF, e);
743       if (isa<math::CeilOp>(def))
744         return addExp(kCeilF, e);
745       if (isa<math::FloorOp>(def))
746         return addExp(kFloorF, e);
747       if (isa<math::SqrtOp>(def))
748         return addExp(kSqrtF, e);
749       if (isa<math::ExpM1Op>(def))
750         return addExp(kExpm1F, e);
751       if (isa<math::Log1pOp>(def))
752         return addExp(kLog1pF, e);
753       if (isa<math::SinOp>(def))
754         return addExp(kSinF, e);
755       if (isa<math::TanhOp>(def))
756         return addExp(kTanhF, e);
757       if (isa<arith::NegFOp>(def))
758         return addExp(kNegF, e); // no negi in std
759       if (isa<arith::TruncFOp>(def))
760         return addExp(kTruncF, e, v);
761       if (isa<arith::ExtFOp>(def))
762         return addExp(kExtF, e, v);
763       if (isa<arith::FPToSIOp>(def))
764         return addExp(kCastFS, e, v);
765       if (isa<arith::FPToUIOp>(def))
766         return addExp(kCastFU, e, v);
767       if (isa<arith::SIToFPOp>(def))
768         return addExp(kCastSF, e, v);
769       if (isa<arith::UIToFPOp>(def))
770         return addExp(kCastUF, e, v);
771       if (isa<arith::ExtSIOp>(def))
772         return addExp(kCastS, e, v);
773       if (isa<arith::ExtUIOp>(def))
774         return addExp(kCastU, e, v);
775       if (isa<arith::IndexCastOp>(def))
776         return addExp(kCastIdx, e, v);
777       if (isa<arith::TruncIOp>(def))
778         return addExp(kTruncI, e, v);
779       if (isa<arith::BitcastOp>(def))
780         return addExp(kBitCast, e, v);
781       if (isa<sparse_tensor::UnaryOp>(def))
782         return addExp(kUnary, e, Value(), def);
783     }
784   }
785   // Construct binary operations if subexpressions can be built.
786   // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
788   if (def->getNumOperands() == 2) {
789     auto x = buildTensorExp(op, def->getOperand(0));
790     auto y = buildTensorExp(op, def->getOperand(1));
791     if (x.hasValue() && y.hasValue()) {
792       unsigned e0 = x.getValue();
793       unsigned e1 = y.getValue();
794       if (isa<arith::MulFOp>(def))
795         return addExp(kMulF, e0, e1);
796       if (isa<complex::MulOp>(def))
797         return addExp(kMulC, e0, e1);
798       if (isa<arith::MulIOp>(def))
799         return addExp(kMulI, e0, e1);
800       if (isa<arith::DivFOp>(def) && !maybeZero(e1))
801         return addExp(kDivF, e0, e1);
802       if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
803         return addExp(kDivS, e0, e1);
804       if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
805         return addExp(kDivU, e0, e1);
806       if (isa<arith::AddFOp>(def))
807         return addExp(kAddF, e0, e1);
808       if (isa<complex::AddOp>(def))
809         return addExp(kAddC, e0, e1);
810       if (isa<arith::AddIOp>(def))
811         return addExp(kAddI, e0, e1);
812       if (isa<arith::SubFOp>(def))
813         return addExp(kSubF, e0, e1);
814       if (isa<arith::SubIOp>(def))
815         return addExp(kSubI, e0, e1);
816       if (isa<arith::AndIOp>(def))
817         return addExp(kAndI, e0, e1);
818       if (isa<arith::OrIOp>(def))
819         return addExp(kOrI, e0, e1);
820       if (isa<arith::XOrIOp>(def))
821         return addExp(kXorI, e0, e1);
822       if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
823         return addExp(kShrS, e0, e1);
824       if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
825         return addExp(kShrU, e0, e1);
826       if (isa<arith::ShLIOp>(def) && isInvariant(e1))
827         return addExp(kShlI, e0, e1);
828       if (isa<sparse_tensor::BinaryOp>(def))
829         return addExp(kBinary, e0, e1, Value(), def);
830     }
831   }
832   // Cannot build.
833   return None;
834 }
835 
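/// Clones the given region, inlines the clone at the current insertion point
/// (with `vals` replacing the block arguments), and returns the value that
/// was yielded by the region.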
836 static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
837                            ValueRange vals) {
  // Make a clone of the given region.
839   Region tmpRegion;
840   BlockAndValueMapping mapper;
841   region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
842   Block &clonedBlock = tmpRegion.front();
843   YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
844   // Merge cloned block and return yield value.
845   Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
846   rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
847   Value val = clonedYield.result();
848   rewriter.eraseOp(clonedYield);
849   rewriter.eraseOp(placeholder);
850   return val;
851 }
852 
853 static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
854                                Operation *op, Value v0) {
855   if (!v0)
856     // Empty input value must be propagated.
857     return Value();
858   UnaryOp unop = cast<UnaryOp>(op);
859   Region &presentRegion = unop.presentRegion();
860   if (presentRegion.empty())
861     // Uninitialized Value() will be interpreted as missing data in the
862     // output.
863     return Value();
864   return insertYieldOp(rewriter, loc, presentRegion, {v0});
865 }
866 
867 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
868                                 Operation *op, Value v0, Value v1) {
869   if (!v0 || !v1)
870     // Empty input values must be propagated.
871     return Value();
872   BinaryOp binop = cast<BinaryOp>(op);
873   Region &overlapRegion = binop.overlapRegion();
874   if (overlapRegion.empty())
875     // Uninitialized Value() will be interpreted as missing data in the
876     // output.
877     return Value();
878   return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
879 }
880 
881 Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
882                        Value v0, Value v1) {
883   switch (tensorExps[e].kind) {
884   case kTensor:
885   case kInvariant:
886   case kIndex:
887     llvm_unreachable("unexpected non-op");
888   // Unary ops.
889   case kAbsF:
890     return rewriter.create<math::AbsOp>(loc, v0);
891   case kCeilF:
892     return rewriter.create<math::CeilOp>(loc, v0);
893   case kFloorF:
894     return rewriter.create<math::FloorOp>(loc, v0);
895   case kSqrtF:
896     return rewriter.create<math::SqrtOp>(loc, v0);
897   case kExpm1F:
898     return rewriter.create<math::ExpM1Op>(loc, v0);
899   case kLog1pF:
900     return rewriter.create<math::Log1pOp>(loc, v0);
901   case kSinF:
902     return rewriter.create<math::SinOp>(loc, v0);
903   case kTanhF:
904     return rewriter.create<math::TanhOp>(loc, v0);
905   case kNegF:
906     return rewriter.create<arith::NegFOp>(loc, v0);
907   case kNegI: // no negi in std
908     return rewriter.create<arith::SubIOp>(
909         loc,
910         rewriter.create<arith::ConstantOp>(loc, v0.getType(),
911                                            rewriter.getZeroAttr(v0.getType())),
912         v0);
913   case kTruncF:
914     return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
915   case kExtF:
916     return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
917   case kCastFS:
918     return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
919   case kCastFU:
920     return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
921   case kCastSF:
922     return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
923   case kCastUF:
924     return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
925   case kCastS:
926     return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
927   case kCastU:
928     return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
929   case kCastIdx:
930     return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
931   case kTruncI:
932     return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
933   case kBitCast:
934     return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
935   // Binary ops.
936   case kMulF:
937     return rewriter.create<arith::MulFOp>(loc, v0, v1);
938   case kMulC:
939     return rewriter.create<complex::MulOp>(loc, v0, v1);
940   case kMulI:
941     return rewriter.create<arith::MulIOp>(loc, v0, v1);
942   case kDivF:
943     return rewriter.create<arith::DivFOp>(loc, v0, v1);
944   case kDivS:
945     return rewriter.create<arith::DivSIOp>(loc, v0, v1);
946   case kDivU:
947     return rewriter.create<arith::DivUIOp>(loc, v0, v1);
948   case kAddF:
949     return rewriter.create<arith::AddFOp>(loc, v0, v1);
950   case kAddC:
951     return rewriter.create<complex::AddOp>(loc, v0, v1);
952   case kAddI:
953     return rewriter.create<arith::AddIOp>(loc, v0, v1);
954   case kSubF:
955     return rewriter.create<arith::SubFOp>(loc, v0, v1);
956   case kSubI:
957     return rewriter.create<arith::SubIOp>(loc, v0, v1);
958   case kAndI:
959     return rewriter.create<arith::AndIOp>(loc, v0, v1);
960   case kOrI:
961     return rewriter.create<arith::OrIOp>(loc, v0, v1);
962   case kXorI:
963     return rewriter.create<arith::XOrIOp>(loc, v0, v1);
964   case kShrS:
965     return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
966   case kShrU:
967     return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
968   case kShlI:
969     return rewriter.create<arith::ShLIOp>(loc, v0, v1);
970   // Semiring ops with custom logic.
971   case kBinaryBranch:
972     return insertYieldOp(rewriter, loc,
973                          *tensorExps[e].op->getBlock()->getParent(), {v0});
974   case kUnary:
975     return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
976   case kBinary:
977     return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
978   }
979   llvm_unreachable("unexpected expression kind in build");
980 }
981 
982 } // namespace sparse_tensor
983 } // namespace mlir
984