1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
11 #include "mlir/Dialect/Complex/IR/Complex.h"
12 #include "mlir/Dialect/Math/IR/Math.h"
13 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
14 
15 #include "mlir/IR/Operation.h"
16 #include "llvm/Support/Debug.h"
17 
18 namespace mlir {
19 namespace sparse_tensor {
20 
21 //===----------------------------------------------------------------------===//
22 // Constructors.
23 //===----------------------------------------------------------------------===//
24 
/// Constructs a tensor expression node of the given kind, asserting per
/// kind which payload slots (`x`, `y`, `v`, `o`) must be present, where
/// -1u / null denote "unused". Only the members relevant to `kind` are
/// stored (the others live in a union, see the header).
TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  case kTensor:
    // Leaf: tensor argument, identified by id `x`.
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    break;
  case kInvariant:
    // Leaf: loop-invariant value carried in `v`.
    assert(x == -1u && y == -1u && v && !o);
    break;
  case kIndex:
    // Leaf: loop index, identified by id `x`.
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    break;
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kCIm:
  case kCRe:
    // Unary operation without value/op payload; `y` stays -1u and is
    // stored as an (unused) second child for uniformity.
    assert(x != -1u && y == -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    // Unary cast; `v` carries the original value whose type determines
    // the destination type of the cast (see inferType()).
    assert(x != -1u && y == -1u && v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinaryBranch:
    // One half of a custom binary operation; `o` is the region's yield.
    assert(x != -1u && y == -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinary:
    // Custom binary operation with both children and the original op.
    assert(x != -1u && y != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    // All remaining (builtin) binary operations: two children only.
    assert(x != -1u && y != -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}
101 
/// Constructs a lattice point over an `n`-bit iteration space in which
/// only bit `b` (a tensor/loop pair) is set, referring to tensor
/// expression `e`. The simplified bit vector stays empty until
/// optimizeSet() computes it.
LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}
106 
/// Constructs a lattice point with the given iteration-space bits `b`,
/// referring to tensor expression `e`.
LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}
109 
110 //===----------------------------------------------------------------------===//
111 // Lattice methods.
112 //===----------------------------------------------------------------------===//
113 
114 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
115                         Operation *op) {
116   unsigned e = tensorExps.size();
117   tensorExps.push_back(TensorExp(k, e0, e1, v, op));
118   return e;
119 }
120 
121 unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
122   assert(t < numTensors && i < numLoops);
123   unsigned p = latPoints.size();
124   latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
125   return p;
126 }
127 
128 unsigned Merger::addSet() {
129   unsigned s = latSets.size();
130   latSets.emplace_back(SmallVector<unsigned, 16>());
131   return s;
132 }
133 
134 unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
135                               Operation *op) {
136   unsigned p = latPoints.size();
137   BitVector nb = BitVector(latPoints[p0].bits);
138   nb |= latPoints[p1].bits;
139   unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
140   latPoints.push_back(LatPoint(nb, e));
141   return p;
142 }
143 
144 unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
145   unsigned s = addSet();
146   for (unsigned p0 : latSets[s0])
147     for (unsigned p1 : latSets[s1])
148       latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
149   return s;
150 }
151 
152 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
153   unsigned s = takeConj(kind, s0, s1, op);
154   // Followed by all in s0.
155   for (unsigned p : latSets[s0])
156     latSets[s].push_back(p);
157   // Map binary 0-y to unary -y.
158   // TODO: move this if-else logic into buildLattices
159   if (kind == kSubF)
160     s1 = mapSet(kNegF, s1);
161   else if (kind == kSubC)
162     s1 = mapSet(kNegC, s1);
163   else if (kind == kSubI)
164     s1 = mapSet(kNegI, s1);
165   // Followed by all in s1.
166   for (unsigned p : latSets[s1])
167     latSets[s].push_back(p);
168   return s;
169 }
170 
171 unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
172                            bool includeLeft, Kind ltrans, Operation *opleft,
173                            bool includeRight, Kind rtrans, Operation *opright) {
174   unsigned s = takeConj(kind, s0, s1, orig);
175   // Left Region.
176   if (includeLeft) {
177     if (opleft)
178       s0 = mapSet(ltrans, s0, Value(), opleft);
179     for (unsigned p : latSets[s0])
180       latSets[s].push_back(p);
181   }
182   // Right Region.
183   if (includeRight) {
184     if (opright)
185       s1 = mapSet(rtrans, s1, Value(), opright);
186     for (unsigned p : latSets[s1])
187       latSets[s].push_back(p);
188   }
189   return s;
190 }
191 
192 unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
193   assert(kAbsF <= kind && kind <= kUnary);
194   unsigned s = addSet();
195   for (unsigned p : latSets[s0]) {
196     unsigned e = addExp(kind, latPoints[p].exp, v, op);
197     latPoints.push_back(LatPoint(latPoints[p].bits, e));
198     latSets[s].push_back(latPoints.size() - 1);
199   }
200   return s;
201 }
202 
/// Optimizes the lattice set `s0` into a new set: drops points that are
/// plain copies of the output tensor or whose condition differs from an
/// already-kept point only in dense dimensions, then computes the
/// simplified condition bits for every retained point.
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  // The first point dominates the set (see the latGT assertion below).
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  // Compute the simplified conditions used during code generation.
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}
231 
/// Computes the simplified condition bits of lattice point `p0` in set
/// `s0`: non-sparse (dense/undef/singleton) bits that need not be tested
/// are cleared. For a singleton point that has at least one sparse
/// dimension, every non-sparse bit is cleared; otherwise only the first
/// non-sparse bit is kept and subsequent ones are cleared.
BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      // Once `reset` is raised, every later non-sparse bit is cleared.
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}
254 
255 bool Merger::latGT(unsigned i, unsigned j) const {
256   const BitVector &bitsi = latPoints[i].bits;
257   const BitVector &bitsj = latPoints[j].bits;
258   assert(bitsi.size() == bitsj.size());
259   if (bitsi.count() > bitsj.count()) {
260     for (unsigned b = 0, be = bitsj.size(); b < be; b++)
261       if (bitsj[b] && !bitsi[b])
262         return false;
263     return true;
264   }
265   return false;
266 }
267 
268 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
269   BitVector tmp = latPoints[j].bits;
270   tmp ^= latPoints[i].bits;
271   return !hasAnyDimOf(tmp, kSparse);
272 }
273 
274 bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
275   for (unsigned b = 0, be = bits.size(); b < be; b++)
276     if (bits[b] && isDim(b, d))
277       return true;
278   return false;
279 }
280 
/// Returns true if tensor expression `e` is guaranteed to be nonzero only
/// where tensor `t` is nonzero (a "single condition" on `t`), which makes
/// certain sparse optimizations valid. Conservatively returns false for
/// any unhandled expression kind.
bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    // Leaf: single condition iff it is tensor `t` itself.
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    // Zero-preserving unary operations: condition carries over from the
    // single operand.
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivC:
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    // Product-like: single condition if one side is a single condition
    // and the other side is a single condition or an invariant.
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddC:
  case kAddI:
    // Sum-like: both sides must be single conditions on `t`.
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}
346 
347 #ifndef NDEBUG
348 
349 //===----------------------------------------------------------------------===//
350 // Print methods (for debugging).
351 //===----------------------------------------------------------------------===//
352 
353 static const char *kindToOpSymbol(Kind kind) {
354   switch (kind) {
355   case kTensor:
356     return "tensor";
357   case kInvariant:
358     return "invariant";
359   case kIndex:
360     return "index";
361   case kAbsF:
362   case kAbsC:
363     return "abs";
364   case kCeilF:
365     return "ceil";
366   case kFloorF:
367     return "floor";
368   case kSqrtF:
369   case kSqrtC:
370     return "sqrt";
371   case kExpm1F:
372   case kExpm1C:
373     return "expm1";
374   case kLog1pF:
375   case kLog1pC:
376     return "log1p";
377   case kSinF:
378   case kSinC:
379     return "sin";
380   case kTanhF:
381   case kTanhC:
382     return "tanh";
383   case kNegF:
384   case kNegC:
385   case kNegI:
386     return "-";
387   case kTruncF:
388   case kExtF:
389   case kCastFS:
390   case kCastFU:
391   case kCastSF:
392   case kCastUF:
393   case kCastS:
394   case kCastU:
395   case kCastIdx:
396   case kTruncI:
397   case kCIm:
398     return "complex.im";
399   case kCRe:
400     return "complex.re";
401   case kBitCast:
402     return "cast";
403   case kBinaryBranch:
404     return "binary_branch";
405   case kUnary:
406     return "unary";
407   case kMulF:
408   case kMulC:
409   case kMulI:
410     return "*";
411   case kDivF:
412   case kDivC:
413   case kDivS:
414   case kDivU:
415     return "/";
416   case kAddF:
417   case kAddC:
418   case kAddI:
419     return "+";
420   case kSubF:
421   case kSubC:
422   case kSubI:
423     return "-";
424   case kAndI:
425     return "&";
426   case kOrI:
427     return "|";
428   case kXorI:
429     return "^";
430   case kShrS:
431     return "a>>";
432   case kShrU:
433     return ">>";
434   case kShlI:
435     return "<<";
436   case kBinary:
437     return "binary";
438   }
439   llvm_unreachable("unexpected kind for symbol");
440 }
441 
442 void Merger::dumpExp(unsigned e) const {
443   switch (tensorExps[e].kind) {
444   case kTensor:
445     if (tensorExps[e].tensor == syntheticTensor)
446       llvm::dbgs() << "synthetic_";
447     else if (tensorExps[e].tensor == outTensor)
448       llvm::dbgs() << "output_";
449     llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
450     break;
451   case kInvariant:
452     llvm::dbgs() << "invariant";
453     break;
454   case kIndex:
455     llvm::dbgs() << "index_" << tensorExps[e].index;
456     break;
457   case kAbsF:
458   case kCeilF:
459   case kFloorF:
460   case kSqrtF:
461   case kSqrtC:
462   case kExpm1F:
463   case kExpm1C:
464   case kLog1pF:
465   case kSinF:
466   case kTanhF:
467   case kTanhC:
468   case kNegF:
469   case kNegI:
470   case kTruncF:
471   case kExtF:
472   case kCastFS:
473   case kCastFU:
474   case kCastSF:
475   case kCastUF:
476   case kCastS:
477   case kCastU:
478   case kCastIdx:
479   case kTruncI:
480   case kBitCast:
481     llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
482     dumpExp(tensorExps[e].children.e0);
483     break;
484   default:
485     llvm::dbgs() << "(";
486     dumpExp(tensorExps[e].children.e0);
487     llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
488     dumpExp(tensorExps[e].children.e1);
489     llvm::dbgs() << ")";
490   }
491 }
492 
493 void Merger::dumpLat(unsigned p) const {
494   llvm::dbgs() << "lat(";
495   dumpBits(latPoints[p].bits);
496   llvm::dbgs() << " :";
497   dumpBits(latPoints[p].simple);
498   llvm::dbgs() << " : ";
499   dumpExp(latPoints[p].exp);
500   llvm::dbgs() << " )\n";
501 }
502 
503 void Merger::dumpSet(unsigned s) const {
504   llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
505   for (unsigned p : latSets[s]) {
506     llvm::dbgs() << "  ";
507     dumpLat(p);
508   }
509   llvm::dbgs() << "}\n";
510 }
511 
512 void Merger::dumpBits(const BitVector &bits) const {
513   for (unsigned b = 0, be = bits.size(); b < be; b++) {
514     if (bits[b]) {
515       unsigned t = tensor(b);
516       unsigned i = index(b);
517       llvm::dbgs() << " i_" << t << "_" << i << "_";
518       switch (dims[t][i]) {
519       case kSparse:
520         llvm::dbgs() << "S";
521         break;
522       case kDense:
523         llvm::dbgs() << "D";
524         break;
525       case kSingle:
526         llvm::dbgs() << "T";
527         break;
528       case kUndef:
529         llvm::dbgs() << "U";
530         break;
531       }
532     }
533   }
534 }
535 
536 #endif // NDEBUG
537 
538 //===----------------------------------------------------------------------===//
539 // Builder methods.
540 //===----------------------------------------------------------------------===//
541 
/// Builds the iteration lattices for tensor expression `e` with respect to
/// loop index `i`, recursing over the expression tree and applying the
/// kind-specific conjunction/disjunction rules illustrated in the truth
/// tables below. Returns the index of the resulting lattice set.
unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kCIm:
  case kCRe:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                  tensorExps[e].op);
  case kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
      Region &absentRegion = unop.absentRegion();

      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      } // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      Value absentVal = absentYield.result();
      unsigned rhs = addExp(kInvariant, absentVal);
      return takeDisj(kind, child0, buildLattices(rhs, i), unop);
    }
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rules applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjuction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
      Region &leftRegion = binop.leftRegion();
      Region &rightRegion = binop.rightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.left_identity() || !leftRegion.empty();
      bool includeRight = binop.right_identity() || !rightRegion.empty();
      return takeCombi(kBinary, child0, child1, binop, includeLeft,
                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                       rightYield);
    }
  }
  llvm_unreachable("unexpected expression kind");
}
727 
728 Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
729   Operation *yield = op.region().front().getTerminator();
730   return buildTensorExp(op, yield->getOperand(0));
731 }
732 
733 /// Only returns false if we are certain this is a nonzero.
734 bool Merger::maybeZero(unsigned e) const {
735   if (tensorExps[e].kind == kInvariant) {
736     if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
737       ArrayAttr arrayAttr = c.getValue();
738       return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
739              arrayAttr[0].cast<FloatAttr>().getValue().isZero();
740     }
741     if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
742       return c.value() == 0;
743     if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
744       return c.value().isZero();
745   }
746   return true;
747 }
748 
/// Returns true if tensor expression `e` is a loop-invariant value.
bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}
752 
753 Type Merger::inferType(unsigned e, Value src) {
754   // Obtain the destination type from the cast node.
755   Type dtp = tensorExps[e].val.getType();
756   // Inspect source type. For vector types, apply the same
757   // vectorization to the destination type.
758   if (auto vtp = src.getType().dyn_cast<VectorType>())
759     return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
760   return dtp;
761 }
762 
/// Recursively builds a tensor expression tree for value `v` inside the
/// linalg.generic op, returning the index of the root expression or None
/// when the computation cannot be admitted. Divisions by possibly-zero
/// values and shifts by non-invariant amounts are rejected here so that
/// the conjunction rules in buildLattices() remain applicable.
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<complex::AbsOp>(def))
        return addExp(kAbsC, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(kSqrtF, e);
      if (isa<complex::SqrtOp>(def))
        return addExp(kSqrtC, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(kExpm1F, e);
      if (isa<complex::Expm1Op>(def))
        return addExp(kExpm1C, e);
      if (isa<math::Log1pOp>(def))
        return addExp(kLog1pF, e);
      if (isa<complex::Log1pOp>(def))
        return addExp(kLog1pC, e);
      if (isa<math::SinOp>(def))
        return addExp(kSinF, e);
      if (isa<complex::SinOp>(def))
        return addExp(kSinC, e);
      if (isa<math::TanhOp>(def))
        return addExp(kTanhF, e);
      if (isa<complex::TanhOp>(def))
        return addExp(kTanhC, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<complex::NegOp>(def))
        return addExp(kNegC, e);
      // Casts also carry `v` so the destination type can be recovered
      // later via inferType().
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (isa<sparse_tensor::UnaryOp>(def))
        return addExp(kUnary, e, Value(), def);
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return addExp(kDivC, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<complex::SubOp>(def))
        return addExp(kSubC, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (isa<sparse_tensor::BinaryOp>(def))
        return addExp(kBinary, e0, e1, Value(), def);
    }
  }
  // Cannot build.
  return None;
}
909 
/// Clones the given region (the body of a custom sparse_tensor op),
/// merges the cloned block at the current insertion point with `vals`
/// substituted for its block arguments, and returns the value produced by
/// the cloned yield terminator. The yield itself and the temporary
/// placeholder op used as the merge anchor are erased.
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of overlap region.
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  // The placeholder gives mergeBlockBefore a concrete insertion anchor.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.result();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}
926 
927 static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
928                                Operation *op, Value v0) {
929   if (!v0)
930     // Empty input value must be propagated.
931     return Value();
932   UnaryOp unop = cast<UnaryOp>(op);
933   Region &presentRegion = unop.presentRegion();
934   if (presentRegion.empty())
935     // Uninitialized Value() will be interpreted as missing data in the
936     // output.
937     return Value();
938   return insertYieldOp(rewriter, loc, presentRegion, {v0});
939 }
940 
941 static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
942                                 Operation *op, Value v0, Value v1) {
943   if (!v0 || !v1)
944     // Empty input values must be propagated.
945     return Value();
946   BinaryOp binop = cast<BinaryOp>(op);
947   Region &overlapRegion = binop.overlapRegion();
948   if (overlapRegion.empty())
949     // Uninitialized Value() will be interpreted as missing data in the
950     // output.
951     return Value();
952   return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
953 }
954 
955 Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
956                        Value v0, Value v1) {
957   switch (tensorExps[e].kind) {
958   case kTensor:
959   case kInvariant:
960   case kIndex:
961     llvm_unreachable("unexpected non-op");
962   // Unary ops.
963   case kAbsF:
964     return rewriter.create<math::AbsOp>(loc, v0);
965   case kAbsC: {
966     auto type = v0.getType().template cast<ComplexType>();
967     auto eltType = type.getElementType().template cast<FloatType>();
968     return rewriter.create<complex::AbsOp>(loc, eltType, v0);
969   }
970   case kCeilF:
971     return rewriter.create<math::CeilOp>(loc, v0);
972   case kFloorF:
973     return rewriter.create<math::FloorOp>(loc, v0);
974   case kSqrtF:
975     return rewriter.create<math::SqrtOp>(loc, v0);
976   case kSqrtC:
977     return rewriter.create<complex::SqrtOp>(loc, v0);
978   case kExpm1F:
979     return rewriter.create<math::ExpM1Op>(loc, v0);
980   case kExpm1C:
981     return rewriter.create<complex::Expm1Op>(loc, v0);
982   case kLog1pF:
983     return rewriter.create<math::Log1pOp>(loc, v0);
984   case kLog1pC:
985     return rewriter.create<complex::Log1pOp>(loc, v0);
986   case kSinF:
987     return rewriter.create<math::SinOp>(loc, v0);
988   case kSinC:
989     return rewriter.create<complex::SinOp>(loc, v0);
990   case kTanhF:
991     return rewriter.create<math::TanhOp>(loc, v0);
992   case kTanhC:
993     return rewriter.create<complex::TanhOp>(loc, v0);
994   case kNegF:
995     return rewriter.create<arith::NegFOp>(loc, v0);
996   case kNegC:
997     return rewriter.create<complex::NegOp>(loc, v0);
998   case kNegI: // no negi in std
999     return rewriter.create<arith::SubIOp>(
1000         loc,
1001         rewriter.create<arith::ConstantOp>(loc, v0.getType(),
1002                                            rewriter.getZeroAttr(v0.getType())),
1003         v0);
1004   case kTruncF:
1005     return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
1006   case kExtF:
1007     return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
1008   case kCastFS:
1009     return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
1010   case kCastFU:
1011     return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
1012   case kCastSF:
1013     return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
1014   case kCastUF:
1015     return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
1016   case kCastS:
1017     return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
1018   case kCastU:
1019     return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
1020   case kCastIdx:
1021     return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
1022   case kTruncI:
1023     return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
1024   case kCIm:
1025   case kCRe: {
1026     auto type = v0.getType().template cast<ComplexType>();
1027     auto eltType = type.getElementType().template cast<FloatType>();
1028     if (tensorExps[e].kind == kCIm)
1029       return rewriter.create<complex::ImOp>(loc, eltType, v0);
1030 
1031     return rewriter.create<complex::ReOp>(loc, eltType, v0);
1032   }
1033   case kBitCast:
1034     return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
1035   // Binary ops.
1036   case kMulF:
1037     return rewriter.create<arith::MulFOp>(loc, v0, v1);
1038   case kMulC:
1039     return rewriter.create<complex::MulOp>(loc, v0, v1);
1040   case kMulI:
1041     return rewriter.create<arith::MulIOp>(loc, v0, v1);
1042   case kDivF:
1043     return rewriter.create<arith::DivFOp>(loc, v0, v1);
1044   case kDivC:
1045     return rewriter.create<complex::DivOp>(loc, v0, v1);
1046   case kDivS:
1047     return rewriter.create<arith::DivSIOp>(loc, v0, v1);
1048   case kDivU:
1049     return rewriter.create<arith::DivUIOp>(loc, v0, v1);
1050   case kAddF:
1051     return rewriter.create<arith::AddFOp>(loc, v0, v1);
1052   case kAddC:
1053     return rewriter.create<complex::AddOp>(loc, v0, v1);
1054   case kAddI:
1055     return rewriter.create<arith::AddIOp>(loc, v0, v1);
1056   case kSubF:
1057     return rewriter.create<arith::SubFOp>(loc, v0, v1);
1058   case kSubC:
1059     return rewriter.create<complex::SubOp>(loc, v0, v1);
1060   case kSubI:
1061     return rewriter.create<arith::SubIOp>(loc, v0, v1);
1062   case kAndI:
1063     return rewriter.create<arith::AndIOp>(loc, v0, v1);
1064   case kOrI:
1065     return rewriter.create<arith::OrIOp>(loc, v0, v1);
1066   case kXorI:
1067     return rewriter.create<arith::XOrIOp>(loc, v0, v1);
1068   case kShrS:
1069     return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
1070   case kShrU:
1071     return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
1072   case kShlI:
1073     return rewriter.create<arith::ShLIOp>(loc, v0, v1);
1074   // Semiring ops with custom logic.
1075   case kBinaryBranch:
1076     return insertYieldOp(rewriter, loc,
1077                          *tensorExps[e].op->getBlock()->getParent(), {v0});
1078   case kUnary:
1079     return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
1080   case kBinary:
1081     return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
1082   }
1083   llvm_unreachable("unexpected expression kind in build");
1084 }
1085 
1086 } // namespace sparse_tensor
1087 } // namespace mlir
1088