//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//
// Constructors.
//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
    : kind(k), val(v) {
  switch (kind) {
  case kTensor:
    assert(x != -1u && y == -1u && !v);
    tensor = x;
    break;
  case kInvariant:
    assert(x == -1u && y == -1u && v);
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    assert(x != -1u && y == -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//
// Lattice methods.
//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}
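
// A note on the bit encoding used above: loop index i and tensor t are
// folded into the single bit position b = numTensors * i + t, so a
// BitVector of size numLoops * numTensors records which (tensor, loop)
// pairs participate in a lattice point (dumpBits() below recovers both
// components through the tensor(b) and index(b) helpers).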

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}
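
// Illustrative sketch: takeConj of s0 = {p, q} and s1 = {r} yields the
// cross product {p&r, q&r}, where each new point unions the bits of its
// operands and combines their expressions under `kind`.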

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  if (kind == kSubF)
    s1 = mapSet(kNegF, s1);
  else if (kind == kSubI)
    s1 = mapSet(kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}
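
// Illustrative sketch: for x(i) - y(i), takeDisj yields the conjunction
// points (x-y where both operands contribute), followed by the points of
// s0 alone (x), followed by the points of s1 rewritten as unary negation
// (-y), which matches the x-y truth table in buildLattices() below.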

unsigned Merger::mapSet(Kind kind, unsigned s0) {
  assert(kAbsF <= kind && kind <= kNegI);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}
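
// E.g. mapSet(kNegF, {p}) yields a fresh set with one new lattice point
// that keeps the bits of p but wraps its expression as a unary negation.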

unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(!latSets[s0].empty());
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == kTensor && tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}
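
// As an example (a sketch, not from any particular test): for x(i) + y(i)
// with x sparse and y dense, the unoptimized set is {x+y, x, y}; the point
// for x differs from the point for x+y in dense dimensions only, so it is
// already covered by the conjunction and is dropped, leaving {x+y, y}.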

llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice: no other point is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  llvm::BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}
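
// E.g. for x(i) * y(i) with x sparse and y dense there is a single (hence
// singleton) lattice point with bits {x, y}; since a sparse dimension is
// involved, the dense bit for y is cleared and the simplified condition
// tests x alone: the generated loop iterates over the stored entries of x
// while y is merely dereferenced.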

bool Merger::latGT(unsigned i, unsigned j) const {
  const llvm::BitVector &bitsi = latPoints[i].bits;
  const llvm::BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}
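
// In other words, latGT(i, j) holds iff the bits of Li form a strict
// superset of the bits of Lj; e.g. {x, y} > {y}, while {x} and {y} are
// incomparable.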

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  llvm::BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, kSparse);
}

bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

bool Merger::isConjunction(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    return tensorExps[e].tensor == t;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    return isConjunction(t, tensorExps[e].children.e0);
  case kDivF: // note: x / c only
  case kDivS:
  case kDivU:
    assert(!maybeZero(tensorExps[e].children.e1));
    return isConjunction(t, tensorExps[e].children.e0);
  case kShrS: // note: x >> inv only
  case kShrU:
  case kShlI:
    assert(isInvariant(tensorExps[e].children.e1));
    return isConjunction(t, tensorExps[e].children.e0);
  case kMulF:
  case kMulI:
  case kAndI:
    return isConjunction(t, tensorExps[e].children.e0) ||
           isConjunction(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}
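
// For instance, x(i) * y(i) is a conjunction for tensor x (the product
// vanishes wherever x does), whereas x(i) + y(i) is not, since the sum can
// be nonzero at entries where x is zero.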

#ifndef NDEBUG

//
// Print methods (for debugging).
//

static const char *kindToOpSymbol(Kind kind) {
  switch (kind) {
  case kTensor:
    return "tensor";
  case kInvariant:
    return "invariant";
  case kAbsF:
    return "abs";
  case kCeilF:
    return "ceil";
  case kFloorF:
    return "floor";
  case kNegF:
    return "-";
  case kNegI:
    return "-";
  case kMulF:
    return "*";
  case kMulI:
    return "*";
  case kDivF:
    return "/";
  case kDivS:
    return "/";
  case kDivU:
    return "/";
  case kAddF:
    return "+";
  case kAddI:
    return "+";
  case kSubF:
    return "-";
  case kSubI:
    return "-";
  case kAndI:
    return "&";
  case kOrI:
    return "|";
  case kXorI:
    return "^";
  case kShrS:
    return "a>>";
  case kShrU:
    return ">>";
  case kShlI:
    return "<<";
  }
  llvm_unreachable("unexpected kind for symbol");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    llvm::dbgs() << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const llvm::BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case kSparse:
        llvm::dbgs() << "S";
        break;
      case kDense:
        llvm::dbgs() << "D";
        break;
      case kSingle:
        llvm::dbgs() << "T";
        break;
      case kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//
// Builder methods.
//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case kTensor:
  case kInvariant: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = addSet();
    unsigned t = kind == kTensor ? tensorExps[e].tensor : syntheticTensor;
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case kAbsF:
  case kCeilF:
  case kFloorF:
  case kNegF:
  case kNegI:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i));
  case kMulF:
  case kMulI:
  case kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kDivF:
  case kDivS:
  case kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    assert(!maybeZero(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kAddF:
  case kAddI:
  case kSubF:
  case kSubI:
  case kOrI:
  case kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case kShrS:
  case kShrU:
  case kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand side of the operator) can be handled
    // with the conjunction rule.
    assert(isInvariant(tensorExps[e].children.e1));
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  }
  llvm_unreachable("unexpected expression kind");
}
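
// Worked example (a sketch): for a(i) * b(i) + c(i), the recursion above
// first builds the singleton conjunction set { a*b } for the product, and
// the subsequent disjunction with c produces the set { a*b+c, a*b, c },
// each point carrying the union of the participating (tensor, loop) bits.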

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
      return c.getValue() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantFloatOp>())
      return c.getValue().isZero();
  }
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == kInvariant;
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(kInvariant, v);
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<AbsFOp>(def))
        return addExp(kAbsF, e);
      if (isa<CeilFOp>(def))
        return addExp(kCeilF, e);
      if (isa<FloorFOp>(def))
        return addExp(kFloorF, e);
      if (isa<NegFOp>(def))
        return addExp(kNegF, e);
      // TODO: no negi in std?
    }
  }
  // Construct binary operations if subexpressions can be built.
  // See buildLattices() for an explanation of rejecting certain divisions
  // and shifts.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return addExp(kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return addExp(kMulI, e0, e1);
      if (isa<DivFOp>(def) && !maybeZero(e1))
        return addExp(kDivF, e0, e1);
      if (isa<SignedDivIOp>(def) && !maybeZero(e1))
        return addExp(kDivS, e0, e1);
      if (isa<UnsignedDivIOp>(def) && !maybeZero(e1))
        return addExp(kDivU, e0, e1);
      if (isa<AddFOp>(def))
        return addExp(kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return addExp(kAddI, e0, e1);
      if (isa<SubFOp>(def))
        return addExp(kSubF, e0, e1);
      if (isa<SubIOp>(def))
        return addExp(kSubI, e0, e1);
      if (isa<AndOp>(def))
        return addExp(kAndI, e0, e1);
      if (isa<OrOp>(def))
        return addExp(kOrI, e0, e1);
      if (isa<XOrOp>(def))
        return addExp(kXorI, e0, e1);
      if (isa<SignedShiftRightOp>(def) && isInvariant(e1))
        return addExp(kShrS, e0, e1);
      if (isa<UnsignedShiftRightOp>(def) && isInvariant(e1))
        return addExp(kShrU, e0, e1);
      if (isa<ShiftLeftOp>(def) && isInvariant(e1))
        return addExp(kShlI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}
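
// As an illustration of the recursion in buildTensorExp(), a kernel body
// such as
//   %0 = mulf %a, %b : f64
//   %1 = addf %0, %c : f64
//   linalg.yield %1 : f64
// (with %a, %b, %c block arguments of the generic op) is built bottom-up
// into the nested expression (tensor_0 * tensor_1) + tensor_2 over the
// block argument positions.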

Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case kTensor:
  case kInvariant:
    llvm_unreachable("unexpected non-op");
  case kAbsF:
    return rewriter.create<AbsFOp>(loc, v0);
  case kCeilF:
    return rewriter.create<CeilFOp>(loc, v0);
  case kFloorF:
    return rewriter.create<FloorFOp>(loc, v0);
  case kNegF:
    return rewriter.create<NegFOp>(loc, v0);
  case kNegI:
    assert(v1); // no negi in std
    return rewriter.create<SubIOp>(loc, v0, v1);
  case kMulF:
    return rewriter.create<MulFOp>(loc, v0, v1);
  case kMulI:
    return rewriter.create<MulIOp>(loc, v0, v1);
  case kDivF:
    return rewriter.create<DivFOp>(loc, v0, v1);
  case kDivS:
    return rewriter.create<SignedDivIOp>(loc, v0, v1);
  case kDivU:
    return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
  case kAddF:
    return rewriter.create<AddFOp>(loc, v0, v1);
  case kAddI:
    return rewriter.create<AddIOp>(loc, v0, v1);
  case kSubF:
    return rewriter.create<SubFOp>(loc, v0, v1);
  case kSubI:
    return rewriter.create<SubIOp>(loc, v0, v1);
  case kAndI:
    return rewriter.create<AndOp>(loc, v0, v1);
  case kOrI:
    return rewriter.create<OrOp>(loc, v0, v1);
  case kXorI:
    return rewriter.create<XOrOp>(loc, v0, v1);
  case kShrS:
    return rewriter.create<SignedShiftRightOp>(loc, v0, v1);
  case kShrU:
    return rewriter.create<UnsignedShiftRightOp>(loc, v0, v1);
  case kShlI:
    return rewriter.create<ShiftLeftOp>(loc, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir