//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v && !o);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    break;
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kNegF:
  case kNegC:
  case kNegI:
  case kCIm:
  case kCRe:
    assert(x != -1u && y == -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinaryBranch:
    assert(x != -1u && y == -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinary:
    assert(x != -1u && y != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

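/// Constructs a new tensor expression and returns its index in the
/// expression pool.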
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
                        Operation *op) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
  return e;
}

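/// Constructs a new lattice point for iterating tensor t in loop i on
/// expression e, and returns its index in the lattice point pool.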
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

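/// Constructs a new, initially empty lattice set and returns its index.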
unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

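/// Computes a single conjunction of two lattice points by taking the union
/// of their iteration bits and combining their expressions with the given
/// operator kind.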
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
                              Operation *op) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

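/// Conjunctive merge of two lattice sets L0 and L1 is the conjunction of
/// their cartesian product.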
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
  return s;
}

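/// Disjunctive merge of two lattice sets L0 and L1 is (L0 /\_op L1, L0, L1).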
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = takeConj(kind, s0, s1, op);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubC)
    s1 = mapSet(kNegC, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

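/// Disjunctive merge of two lattice sets L0 and L1 with custom handling of
/// the overlap, left, and right regions; either side may be left out of
/// the result.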
unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
                           bool includeLeft, Kind ltrans, Operation *opleft,
                           bool includeRight, Kind rtrans, Operation *opright) {
  unsigned s = takeConj(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
  }
  return s;
}

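/// Maps the unary operator over the lattice set of the operand, i.e. each
/// lattice point on an expression E is simply copied over, but with OP E
/// as the new expression.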
unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
  assert(kAbsF <= kind && kind <= kUnary);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v, op);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

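/// Optimizes the iteration lattice points in the given set. This method
/// should be called right before code generation to avoid generating
/// redundant loops and conditions.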
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice so that no other point is less than it.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

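/// Returns true if Li > Lj, i.e. the iteration bits of Li form a proper
/// superset of those of Lj.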
bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

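/// Returns true if Li and Lj only differ in dense dimensions.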
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

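/// Returns true if the given tensor iterates *only* in the given tensor
/// expression. For the output tensor, this defines a "simply dynamic"
/// operation [Bik96].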
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivC:
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddC:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  case kAbsF:
  case kAbsC:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kSqrtF:
    return "sqrt";
  case kExpm1F:
    return "expm1";
  case kLog1pF:
  case kLog1pC:
    return "log1p";
  case kSinF:
  case kSinC:
    return "sin";
  case kTanhF:
    return "tanh";
  case kNegF:
  case kNegC:
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    return "cast";
  case kCIm:
    return "complex.im";
  case kCRe:
    return "complex.re";
  case kBinaryBranch:
    return "binary_branch";
  case kUnary:
    return "unary";
  case kMulF:
  case kMulC:
  case kMulI:
    return "*";
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    return "/";
  case kAddF:
  case kAddC:
  case kAddI:
    return "+";
  case kSubF:
  case kSubC:
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  case kBinary:
    return "binary";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kNegF:
  case kNegC:
  case kNegI:
  case kCIm:
  case kCRe:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

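/// Builds the iteration lattices in a bottom-up traversal given the
/// remaining tensor (sub)expression and the next loop index in the
/// iteration graph. Returns the index of the root lattice set.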
unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kCIm:
  case kCRe:
  case kFloorF:
  case kSqrtF:
  case kExpm1F:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                  tensorExps[e].op);
  case kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
      Region &absentRegion = unop.absentRegion();

      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as
      // an invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      Value absentVal = absentYield.result();
      unsigned rhs = addExp(kInvariant, absentVal);
      return takeDisj(kind, child0, buildLattices(rhs, i), unop);
    }
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
      Region &leftRegion = binop.leftRegion();
      Region &rightRegion = binop.rightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.left_identity() || !leftRegion.empty();
      bool includeRight = binop.right_identity() || !rightRegion.empty();
      return takeCombi(kBinary, child0, child1, binop, includeLeft,
                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                       rightYield);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
             arrayAttr[1].cast<FloatAttr>().getValue().isZero();
    }
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}

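/// Builds the tensor (sub)expression rooted at value v within the given
/// Linalg operation, or returns None if the expression cannot be handled.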
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<complex::AbsOp>(def))
        return addExp(kAbsC, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(kSqrtF, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(kExpm1F, e);
      if (isa<math::Log1pOp>(def))
        return addExp(kLog1pF, e);
      if (isa<complex::Log1pOp>(def))
        return addExp(kLog1pC, e);
      if (isa<math::SinOp>(def))
        return addExp(kSinF, e);
      if (isa<complex::SinOp>(def))
        return addExp(kSinC, e);
      if (isa<math::TanhOp>(def))
        return addExp(kTanhF, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<complex::NegOp>(def))
        return addExp(kNegC, e);
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (isa<sparse_tensor::UnaryOp>(def))
        return addExp(kUnary, e, Value(), def);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return addExp(kDivC, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<complex::SubOp>(def))
        return addExp(kSubC, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (isa<sparse_tensor::BinaryOp>(def))
        return addExp(kBinary, e0, e1, Value(), def);
    }
  }
  // Cannot build.
  return None;
}

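/// Clones the given region at the current insertion point, remapping its
/// block arguments to the given values, and returns the yielded value.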
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of overlap region.
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.result();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.presentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.overlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

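/// Rebuilds SSA format from a tensor expression.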
Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary ops.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kAbsC: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kCIm:
  case kCRe: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    if (tensorExps[e].kind == kCIm)
      return rewriter.create<complex::ImOp>(loc, eltType, v0);
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary ops.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  // Semiring ops with custom logic.
  case kBinaryBranch:
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
  case kUnary:
    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
  case kBinary:
    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir