//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//===----------------------------------------------------------------------===//
// Constructors.
//===----------------------------------------------------------------------===//

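// Constructs a tensor expression node of the given kind; how the operands
// x, y, v, and o are interpreted depends on the kind, as spelled out by the
// assertions below.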
TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v, Operation *o)
    : kind(k), val(v), op(o) {
  switch (kind) {
  // Leaf.
  case kTensor:
    assert(x != -1u && y == -1u && !v && !o);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v && !o);
    break;
  case kIndex:
    assert(x != -1u && y == -1u && !v && !o);
    index = x;
    break;
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kCIm:
  case kCRe:
    assert(x != -1u && y == -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kBitCast:
    assert(x != -1u && y == -1u && v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinaryBranch:
    assert(x != -1u && y == -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kUnary:
    // No assertion on y can be made, as the branching paths involve both
    // a unary (mapSet) and binary (takeDisj) pathway.
    assert(x != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kAndI:
  case kOrI:
  case kXorI:
  case kShrS:
  case kShrU:
  case kShlI:
    assert(x != -1u && y != -1u && !v && !o);
    children.e0 = x;
    children.e1 = y;
    break;
  case kBinary:
    assert(x != -1u && y != -1u && !v && o);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//===----------------------------------------------------------------------===//
// Lattice methods.
//===----------------------------------------------------------------------===//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v,
                        Operation *op) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v, op));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

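// Constructs a new lattice point as the conjunction of two given points,
// taking the union of their iteration bits and wrapping their expressions
// in a new node of the given kind.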
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1,
                              Operation *op) {
  unsigned p = latPoints.size();
  BitVector nb = BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp, Value(), op);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

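// Conjunctive merge of two lattice sets: the cross product of all pairs
// of lattice points.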
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1, op));
  return s;
}

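// Disjunctive merge of two lattice sets: the conjunctive cross product,
// followed by all points of the left set and then all points of the right
// set (with subtraction rewritten to negation for the right set).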
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1, Operation *op) {
  unsigned s = takeConj(kind, s0, s1, op);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  // TODO: move this if-else logic into buildLattices
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubC)
    s1 = mapSet(kNegC, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

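// Disjunctive merge with custom handling of the overlap, left, and right
// regions; used to lower the custom sparse_tensor.binary semi-ring operation.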
unsigned Merger::takeCombi(Kind kind, unsigned s0, unsigned s1, Operation *orig,
                           bool includeLeft, Kind ltrans, Operation *opleft,
                           bool includeRight, Kind rtrans, Operation *opright) {
  unsigned s = takeConj(kind, s0, s1, orig);
  // Left Region.
  if (includeLeft) {
    if (opleft)
      s0 = mapSet(ltrans, s0, Value(), opleft);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
  }
  // Right Region.
  if (includeRight) {
    if (opright)
      s1 = mapSet(rtrans, s1, Value(), opright);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
  }
  return s;
}

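// Maps a lattice set through a unary operator: every lattice point keeps
// its iteration bits, but its expression is wrapped in a new node of the
// given kind.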
unsigned Merger::mapSet(Kind kind, unsigned s0, Value v, Operation *op) {
  assert(kAbsF <= kind && kind <= kUnary);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp, v, op);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

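// Optimizes a lattice set by removing lattice points whose conditions are
// already covered by a previously retained point (only dense dimensions
// differ), then simplifies the conditions of the remaining points.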
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

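// Simplifies the condition bits of a lattice point: at most one non-sparse
// condition is kept, and for a singleton point that still carries a sparse
// condition, all non-sparse conditions are dropped.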
BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

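// Returns true if lattice point i is strictly greater than lattice point j,
// i.e. the bits of i form a proper superset of the bits of j.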
bool Merger::latGT(unsigned i, unsigned j) const {
  const BitVector &bitsi = latPoints[i].bits;
  const BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isSingleCondition(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  // Leaf.
  case kTensor:
    return tensorExps[e].tensor == t;
  case kInvariant:
  case kIndex:
    return false;
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kBinaryBranch:
  case kUnary:
    return false;
  // Binary operations.
  case kDivF: // note: x / c only
  case kDivC:
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isSingleCondition(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    if (isSingleCondition(t, tensorExps[e].children.e0))
      return isSingleCondition(t, tensorExps[e].children.e1) ||
             isInvariant(tensorExps[e].children.e1);
    if (isSingleCondition(t, tensorExps[e].children.e1))
      return isInvariant(tensorExps[e].children.e0);
    return false;
  case kAddF:
  case kAddC:
  case kAddI:
    return isSingleCondition(t, tensorExps[e].children.e0) &&
           isSingleCondition(t, tensorExps[e].children.e1);
  case kSubF:
  case kSubC:
  case kSubI:
  case kOrI:
  case kXorI:
  case kBinary:
    return false;
  }
  llvm_unreachable("unexpected kind");
}

#ifndef NDEBUG

//===----------------------------------------------------------------------===//
// Print methods (for debugging).
//===----------------------------------------------------------------------===//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  // Leaf.
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kIndex:
    return "index";
  // Unary operations.
  case kAbsF:
  case kAbsC:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kSqrtF:
  case kSqrtC:
    return "sqrt";
  case kExpm1F:
  case kExpm1C:
    return "expm1";
  case kLog1pF:
  case kLog1pC:
    return "log1p";
  case kSinF:
  case kSinC:
    return "sin";
  case kTanhF:
  case kTanhC:
    return "tanh";
  case kNegF:
  case kNegC:
  case kNegI:
    return "-";
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
    return "complex.im";
  case kCRe:
    return "complex.re";
  case kBitCast:
    return "cast";
  case kBinaryBranch:
    return "binary_branch";
  case kUnary:
    return "unary";
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
    return "*";
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    return "/";
  case kAddF:
  case kAddC:
  case kAddI:
    return "+";
  case kSubF:
  case kSubC:
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  case kBinary:
    return "binary";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  // Leaf.
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kIndex:
    llvm::dbgs() << "index_" << tensorExps[e].index;
    break;
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
  case kBinaryBranch:
  case kUnary:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kAndI:
  case kOrI:
  case kXorI:
  case kShrS:
  case kShrU:
  case kShlI:
  case kBinary:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//===----------------------------------------------------------------------===//
// Builder methods.
//===----------------------------------------------------------------------===//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  // Leaf.
  case kTensor:
  case kInvariant:
  case kIndex: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression,
    // a proper index value, and a truly dynamic sparse output tensor are set
    // to a synthetic tensor with undefined indices only to ensure the
    // iteration space is not skipped as a result of their contents.
    unsigned s = addSet();
    unsigned t = syntheticTensor;
    if (kind == kTensor) {
      t = tensorExps[e].tensor;
      if (hasSparseOut && t == outTensor)
        t = syntheticTensor;
    }
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  // Unary operations.
  case kAbsF:
  case kAbsC:
  case kCeilF:
  case kFloorF:
  case kSqrtF:
  case kSqrtC:
  case kExpm1F:
  case kExpm1C:
  case kLog1pF:
  case kLog1pC:
  case kSinF:
  case kSinC:
  case kTanhF:
  case kTanhC:
  case kNegF:
  case kNegC:
  case kNegI:
  case kTruncF:
  case kExtF:
  case kCastFS:
  case kCastFU:
  case kCastSF:
  case kCastUF:
  case kCastS:
  case kCastU:
  case kCastIdx:
  case kTruncI:
  case kCIm:
  case kCRe:
  case kBitCast:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i),
                  tensorExps[e].val);
  case kBinaryBranch:
    // The left or right half of a binary operation which has already
    // been split into separate operations for each region.
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i), Value(),
                  tensorExps[e].op);
  case kUnary:
    // A custom unary operation.
    //
    //  op y|    !y    |     y      |
    //  ----+----------+------------+
    //      | absent() | present(y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      UnaryOp unop = cast<UnaryOp>(tensorExps[e].op);
      Region &absentRegion = unop.getAbsentRegion();

      if (absentRegion.empty()) {
        // Simple mapping over existing values.
        return mapSet(kind, child0, Value(), unop);
      }
      // Use a disjunction with `unop` on the left and the absent value as an
      // invariant on the right.
      Block &absentBlock = absentRegion.front();
      YieldOp absentYield = cast<YieldOp>(absentBlock.getTerminator());
      Value absentVal = absentYield.getResult();
      unsigned rhs = addExp(kInvariant, absentVal);
      return takeDisj(kind, child0, buildLattices(rhs, i), unop);
    }
  // Binary operations.
  case kMulF:
  case kMulC:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    //
    // Note even here, 0*NaN=NaN and 0*Inf=NaN, but that is ignored.
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivC:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddC:
  case kAddI:
  case kSubF:
  case kSubC:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kBinary:
    // A custom binary operation.
    //
    //  x op y|   !y    |       y      |
    //  ------+---------+--------------+
    //    !x  |  empty  |   right(y)   |
    //     x  | left(x) | overlap(x,y) |
    {
      unsigned child0 = buildLattices(tensorExps[e].children.e0, i);
      unsigned child1 = buildLattices(tensorExps[e].children.e1, i);
      BinaryOp binop = cast<BinaryOp>(tensorExps[e].op);
      Region &leftRegion = binop.getLeftRegion();
      Region &rightRegion = binop.getRightRegion();
      // Left Region.
      Operation *leftYield = nullptr;
      if (!leftRegion.empty()) {
        Block &leftBlock = leftRegion.front();
        leftYield = leftBlock.getTerminator();
      }
      // Right Region.
      Operation *rightYield = nullptr;
      if (!rightRegion.empty()) {
        Block &rightBlock = rightRegion.front();
        rightYield = rightBlock.getTerminator();
      }
      bool includeLeft = binop.getLeftIdentity() || !leftRegion.empty();
      bool includeRight = binop.getRightIdentity() || !rightRegion.empty();
      return takeCombi(kBinary, child0, child1, binop, includeLeft,
                       kBinaryBranch, leftYield, includeRight, kBinaryBranch,
                       rightYield);
    }
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  // Build the linalg semantics backward from yield.
  Operation *yield = op.region().front().getTerminator();
  assert(isa<linalg::YieldOp>(yield));
  return buildTensorExp(op, yield->getOperand(0));
}

/// Only returns false if we are certain this is a nonzero.
bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<complex::ConstantOp>()) {
      ArrayAttr arrayAttr = c.getValue();
      return arrayAttr[0].cast<FloatAttr>().getValue().isZero() &&
             arrayAttr[1].cast<FloatAttr>().getValue().isZero();
    }
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantIntOp>())
      return c.value() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<arith::ConstantFloatOp>())
      return c.value().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Type Merger::inferType(unsigned e, Value src) {
  // Obtain the destination type from the cast node.
  Type dtp = tensorExps[e].val.getType();
  // Inspect source type. For vector types, apply the same
  // vectorization to the destination type.
  if (auto vtp = src.getType().dyn_cast<VectorType>())
    return VectorType::get(vtp.getNumElements(), dtp, vtp.getNumScalableDims());
  return dtp;
}

/// Ensures that sparse compiler can generate code for expression.
static bool isAdmissableBranchExp(Operation *op, Block *block, Value v) {
  // Arguments are always admissable.
  if (auto arg = v.dyn_cast<BlockArgument>())
    return true;
  // Accept index anywhere.
  Operation *def = v.getDefiningOp();
  if (isa<linalg::IndexOp>(def))
    return true;
  // Operation defined outside branch.
  if (def->getBlock() != block) {
    return def->getBlock() != op->getBlock(); // invariant?
  }
  // Operation defined within branch. Anything is accepted,
  // as long as all subexpressions are admissable.
  for (unsigned i = 0, n = def->getNumOperands(); i < n; i++)
    if (!isAdmissableBranchExp(op, block, def->getOperand(i)))
      return false;
  return true;
}

/// Ensures that sparse compiler can generate code for branch.
static bool isAdmissableBranch(Operation *op, Region &region) {
  if (region.empty())
    return true;
  // Build the semi-ring branch semantics backward from yield.
  Operation *yield = region.front().getTerminator();
  assert(isa<YieldOp>(yield));
  return isAdmissableBranchExp(op, &region.front(), yield->getOperand(0));
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct index operations.
  if (def->getNumOperands() == 0) {
    if (auto indexOp = dyn_cast<linalg::IndexOp>(def))
      return addExp(kIndex, indexOp.dim());
  }
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.has_value()) {
      unsigned e = x.value();
      if (isa<math::AbsOp>(def))
        return addExp(kAbsF, e);
      if (isa<complex::AbsOp>(def))
        return addExp(kAbsC, e);
      if (isa<math::CeilOp>(def))
        return addExp(kCeilF, e);
      if (isa<math::FloorOp>(def))
        return addExp(kFloorF, e);
      if (isa<math::SqrtOp>(def))
        return addExp(kSqrtF, e);
      if (isa<complex::SqrtOp>(def))
        return addExp(kSqrtC, e);
      if (isa<math::ExpM1Op>(def))
        return addExp(kExpm1F, e);
      if (isa<complex::Expm1Op>(def))
        return addExp(kExpm1C, e);
      if (isa<math::Log1pOp>(def))
        return addExp(kLog1pF, e);
      if (isa<complex::Log1pOp>(def))
        return addExp(kLog1pC, e);
      if (isa<math::SinOp>(def))
        return addExp(kSinF, e);
      if (isa<complex::SinOp>(def))
        return addExp(kSinC, e);
      if (isa<math::TanhOp>(def))
        return addExp(kTanhF, e);
      if (isa<complex::TanhOp>(def))
        return addExp(kTanhC, e);
      if (isa<arith::NegFOp>(def))
        return addExp(kNegF, e); // no negi in std
      if (isa<complex::NegOp>(def))
        return addExp(kNegC, e);
      if (isa<arith::TruncFOp>(def))
        return addExp(kTruncF, e, v);
      if (isa<arith::ExtFOp>(def))
        return addExp(kExtF, e, v);
      if (isa<arith::FPToSIOp>(def))
        return addExp(kCastFS, e, v);
      if (isa<arith::FPToUIOp>(def))
        return addExp(kCastFU, e, v);
      if (isa<arith::SIToFPOp>(def))
        return addExp(kCastSF, e, v);
      if (isa<arith::UIToFPOp>(def))
        return addExp(kCastUF, e, v);
      if (isa<arith::ExtSIOp>(def))
        return addExp(kCastS, e, v);
      if (isa<arith::ExtUIOp>(def))
        return addExp(kCastU, e, v);
      if (isa<arith::IndexCastOp>(def))
        return addExp(kCastIdx, e, v);
      if (isa<arith::TruncIOp>(def))
        return addExp(kTruncI, e, v);
      if (isa<complex::ImOp>(def))
        return addExp(kCIm, e);
      if (isa<complex::ReOp>(def))
        return addExp(kCRe, e);
      if (isa<arith::BitcastOp>(def))
        return addExp(kBitCast, e, v);
      if (auto unop = dyn_cast<sparse_tensor::UnaryOp>(def)) {
        if (isAdmissableBranch(unop, unop.getPresentRegion()) &&
            isAdmissableBranch(unop, unop.getAbsentRegion()))
          return addExp(kUnary, e, Value(), def);
      }
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain
  // division and shift operations.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.has_value() && y.has_value()) {
      unsigned e0 = x.value();
      unsigned e1 = y.value();
      if (isa<arith::MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<complex::MulOp>(def))
        return addExp(kMulC, e0, e1);
      if (isa<arith::MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<arith::DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<complex::DivOp>(def) && !maybeZero(e1))
        return addExp(kDivC, e0, e1);
      if (isa<arith::DivSIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<arith::DivUIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<arith::AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<complex::AddOp>(def))
        return addExp(kAddC, e0, e1);
      if (isa<arith::AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<arith::SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<complex::SubOp>(def))
        return addExp(kSubC, e0, e1);
      if (isa<arith::SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<arith::AndIOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<arith::OrIOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<arith::XOrIOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<arith::ShRSIOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<arith::ShRUIOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<arith::ShLIOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
      if (auto binop = dyn_cast<sparse_tensor::BinaryOp>(def)) {
        if (isAdmissableBranch(binop, binop.getOverlapRegion()) &&
            (binop.getLeftIdentity() ||
             isAdmissableBranch(binop, binop.getLeftRegion())) &&
            (binop.getRightIdentity() ||
             isAdmissableBranch(binop, binop.getRightRegion())))
          return addExp(kBinary, e0, e1, Value(), def);
      }
    }
  }
  // Cannot build.
  return None;
}

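// Clones the given single-block region (ending in a sparse_tensor.yield),
// merges the clone at the current insertion point with `vals` substituted
// for its block arguments, and returns the value that was yielded.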
static Value insertYieldOp(RewriterBase &rewriter, Location loc, Region &region,
                           ValueRange vals) {
  // Make a clone of overlap region.
  Region tmpRegion;
  BlockAndValueMapping mapper;
  region.cloneInto(&tmpRegion, tmpRegion.begin(), mapper);
  Block &clonedBlock = tmpRegion.front();
  YieldOp clonedYield = cast<YieldOp>(clonedBlock.getTerminator());
  // Merge cloned block and return yield value.
  Operation *placeholder = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.mergeBlockBefore(&tmpRegion.front(), placeholder, vals);
  Value val = clonedYield.getResult();
  rewriter.eraseOp(clonedYield);
  rewriter.eraseOp(placeholder);
  return val;
}

static Value buildUnaryPresent(RewriterBase &rewriter, Location loc,
                               Operation *op, Value v0) {
  if (!v0)
    // Empty input value must be propagated.
    return Value();
  UnaryOp unop = cast<UnaryOp>(op);
  Region &presentRegion = unop.getPresentRegion();
  if (presentRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, presentRegion, {v0});
}

static Value buildBinaryOverlap(RewriterBase &rewriter, Location loc,
                                Operation *op, Value v0, Value v1) {
  if (!v0 || !v1)
    // Empty input values must be propagated.
    return Value();
  BinaryOp binop = cast<BinaryOp>(op);
  Region &overlapRegion = binop.getOverlapRegion();
  if (overlapRegion.empty())
    // Uninitialized Value() will be interpreted as missing data in the
    // output.
    return Value();
  return insertYieldOp(rewriter, loc, overlapRegion, {v0, v1});
}

Value Merger::buildExp(RewriterBase &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  // Leaf.
  case kTensor:
  case kInvariant:
  case kIndex:
    llvm_unreachable("unexpected non-op");
  // Unary operations.
  case kAbsF:
    return rewriter.create<math::AbsOp>(loc, v0);
  case kAbsC: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::AbsOp>(loc, eltType, v0);
  }
  case kCeilF:
    return rewriter.create<math::CeilOp>(loc, v0);
  case kFloorF:
    return rewriter.create<math::FloorOp>(loc, v0);
  case kSqrtF:
    return rewriter.create<math::SqrtOp>(loc, v0);
  case kSqrtC:
    return rewriter.create<complex::SqrtOp>(loc, v0);
  case kExpm1F:
    return rewriter.create<math::ExpM1Op>(loc, v0);
  case kExpm1C:
    return rewriter.create<complex::Expm1Op>(loc, v0);
  case kLog1pF:
    return rewriter.create<math::Log1pOp>(loc, v0);
  case kLog1pC:
    return rewriter.create<complex::Log1pOp>(loc, v0);
  case kSinF:
    return rewriter.create<math::SinOp>(loc, v0);
  case kSinC:
    return rewriter.create<complex::SinOp>(loc, v0);
  case kTanhF:
    return rewriter.create<math::TanhOp>(loc, v0);
  case kTanhC:
    return rewriter.create<complex::TanhOp>(loc, v0);
  case kNegF:
    return rewriter.create<arith::NegFOp>(loc, v0);
  case kNegC:
    return rewriter.create<complex::NegOp>(loc, v0);
  case kNegI: // no negi in std
    return rewriter.create<arith::SubIOp>(
        loc,
        rewriter.create<arith::ConstantOp>(loc, v0.getType(),
                                           rewriter.getZeroAttr(v0.getType())),
        v0);
  case kTruncF:
    return rewriter.create<arith::TruncFOp>(loc, inferType(e, v0), v0);
  case kExtF:
    return rewriter.create<arith::ExtFOp>(loc, inferType(e, v0), v0);
  case kCastFS:
    return rewriter.create<arith::FPToSIOp>(loc, inferType(e, v0), v0);
  case kCastFU:
    return rewriter.create<arith::FPToUIOp>(loc, inferType(e, v0), v0);
  case kCastSF:
    return rewriter.create<arith::SIToFPOp>(loc, inferType(e, v0), v0);
  case kCastUF:
    return rewriter.create<arith::UIToFPOp>(loc, inferType(e, v0), v0);
  case kCastS:
    return rewriter.create<arith::ExtSIOp>(loc, inferType(e, v0), v0);
  case kCastU:
    return rewriter.create<arith::ExtUIOp>(loc, inferType(e, v0), v0);
  case kCastIdx:
    return rewriter.create<arith::IndexCastOp>(loc, inferType(e, v0), v0);
  case kTruncI:
    return rewriter.create<arith::TruncIOp>(loc, inferType(e, v0), v0);
  case kCIm: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::ImOp>(loc, eltType, v0);
  }
  case kCRe: {
    auto type = v0.getType().cast<ComplexType>();
    auto eltType = type.getElementType().cast<FloatType>();
    return rewriter.create<complex::ReOp>(loc, eltType, v0);
  }
  case kBitCast:
    return rewriter.create<arith::BitcastOp>(loc, inferType(e, v0), v0);
  // Binary operations.
  case kMulF:
    return rewriter.create<arith::MulFOp>(loc, v0, v1);
  case kMulC:
    return rewriter.create<complex::MulOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<arith::MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<arith::DivFOp>(loc, v0, v1);
  case kDivC:
    return rewriter.create<complex::DivOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<arith::DivSIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<arith::DivUIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<arith::AddFOp>(loc, v0, v1);
  case kAddC:
    return rewriter.create<complex::AddOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<arith::AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<arith::SubFOp>(loc, v0, v1);
  case kSubC:
    return rewriter.create<complex::SubOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<arith::SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<arith::AndIOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<arith::OrIOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<arith::XOrIOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<arith::ShRSIOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<arith::ShRUIOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<arith::ShLIOp>(loc, v0, v1);
  case kBinaryBranch: // semi-ring ops with custom logic.
    return insertYieldOp(rewriter, loc,
                         *tensorExps[e].op->getBlock()->getParent(), {v0});
  case kUnary:
    return buildUnaryPresent(rewriter, loc, tensorExps[e].op, v0);
  case kBinary:
    return buildBinaryOverlap(rewriter, loc, tensorExps[e].op, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir