//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v && !o);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
  case kCIm:
  case kCRe:
    assert(x != -1u && y == -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinaryBranch:
    assert(x != -1u && y == -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinary:
    assert(x != -1u && y != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
                        Operation *op) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
  return e;
}

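// Adds a lattice point for tensor t at loop index i. The involved dimensions
// are tracked in a bit vector of size numLoops * numTensors, where tensor t
// at loop i occupies bit (numTensors * i + t); e.g., with numTensors = 3,
// tensor 1 at loop 2 maps to bit 3 * 2 + 1 = 7.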
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

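// Computes the conjunction of two lattice points: the new point takes the
// union of their loop/tensor bits and pairs their expressions under `kind`.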
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
                              Operation *op) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

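// Conjunction combines every lattice point of the first set with every
// lattice point of the second set; e.g., for sets {p, q} and {r} the
// result is the new set {p&r, q&r}.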
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
  return s;
}

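// Disjunction keeps the pairwise conjunctions as well as the points of each
// individual set, so that for operand sets {x} and {y} the result is
// { x&y, x, y }; for subtraction, the points of the right-hand set are first
// rewritten from binary 0-y to unary -y.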
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = takeConj(kind, s0, s1, op);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
                           bool includeLeft, Kind ltrans, Operation *opleft,
                           bool includeRight, Kind rtrans, Operation *opright) {
  unsigned s = takeConj(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
  }
  return s;
}

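// Maps a unary operation over all lattice points of a set: every point in
// the new set keeps its original bits but wraps its expression in `kind`
// (e.g., mapping kNegF over { y } yields { -y }).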
unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
  assert(kAbsF <= kind && kind <= kUnary);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v, op);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

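// Optimizes a lattice set by dropping points that are subsumed: a point
// whose expression is a mere copy of the output tensor, or a point whose
// condition differs only in dense dimensions from a point already kept.
// The conditions of the surviving points are then simplified.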
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice; no other point is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

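// Returns true if lattice point i is strictly greater than lattice point j,
// i.e., the bits of i form a proper superset of the bits of j.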
bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddC:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kSqrtF:
    return "sqrt";
  case kExpm1F:
    return "expm1";
  case kLog1pF:
    return "log1p";
  case kSinF:
    return "sin";
  case kTanhF:
    return "tanh";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kCIm:
    return "complex.im";
  case kCRe:
    return "complex.re";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kBinaryBranch:
    return "binary_branch";
  case kUnary:
    return "unary";
  case kMulF:
  case kMulC:
  case kMulI:
    return "*";
  case kDivF:
  case kDivS:
  case kDivU:
    return "/";
  case kAddF:
  case kAddC:
  case kAddI:
    return "+";
  case kSubF:
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  case kBinary:
    return "binary";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

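// Builds the iteration lattices for tensor expression e at loop index i by
// recursively combining the lattices of its operands, and returns the index
// of the resulting lattice set. For example, a(i) * b(i) yields the single
// conjunction { a&b }, while a(i) + b(i) yields the disjunction { a&b, a, b }.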
unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kCIm:
  case kCRe:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kSinF:
  case kTanhF:
  case kNegF:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                  tensorExps[e].op);
  case kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
      Region &absentRegion = unop.absentRegion();

      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      Value absentVal = absentYield.result();
      unsigned rhs = addExp(kInvariant, absentVal);
      return takeDisj(kind, child0, buildLattices(rhs, i), unop);
    }
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur on the left-hand side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
      Region &leftRegion = binop.leftRegion();
      Region &rightRegion = binop.rightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.left_identity() || !leftRegion.empty();
      bool includeRight = binop.right_identity() || !rightRegion.empty();
      return takeCombi(kBinary, child0, child1, binop, includeLeft,
                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                       rightYield);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}

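// Recursively builds a tensor expression for the given value inside the
// linalg.generic region by visiting its defining operations. Returns None
// when the expression cannot be admitted (e.g., a division by a possibly
// zero operand or a shift by a non-invariant amount).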
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(kSqrtF, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(kExpm1F, e);
      if (isa<math::Log1pOp>(def))
        return addExp(kLog1pF, e);
      if (isa<math::SinOp>(def))
        return addExp(kSinF, e);
      if (isa<math::TanhOp>(def))
        return addExp(kTanhF, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (isa<sparse_tensor::UnaryOp>(def))
        return addExp(kUnary, e, Value(), def);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (isa<sparse_tensor::BinaryOp>(def))
        return addExp(kBinary, e0, e1, Value(), def);
    }
  }
  // Cannot build.
  return None;
}

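// Clones the given region, merges its block at the current insertion point
// with `vals` substituted for the block arguments, and returns the value
// produced by the cloned yield (which is then erased).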
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of the region.
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.result();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.presentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.overlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kCIm:
  case kCRe: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    if (tensorExps[e].kind == kCIm)
      return rewriter.create<complex::ImOp>(loc, eltType, v0);
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary ops.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  // Semiring ops with custom logic.
  case kBinaryBranch:
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
  case kUnary:
    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
  case kBinary:
    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir