//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//
// Constructors.
//

TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
    : kind(k), val(v) {
  switch (kind) {
  case Kind::kTensor:
    assert(x != -1u && y == -1u && !v);
    tensor = x;
    break;
  case Kind::kInvariant:
    assert(x == -1u && y == -1u && v);
    break;
  case Kind::kAbsF:
  case Kind::kCeilF:
  case Kind::kFloorF:
  case Kind::kNegF:
  case Kind::kNegI:
    assert(x != -1u && y == -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  default:
    assert(x != -1u && y != -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

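// Constructs a lattice point with a single condition bit b set among n bits,
// where callers pass n = numTensors * numLoops (see addLat below).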
LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//
// Lattice methods.
//

unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
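  // Bit (numTensors * i + t) uniquely encodes the (tensor t, loop i) pair;
  // dumpBits() decodes it again through tensor(b) and index(b).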
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
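  // A conjunction covers an iteration only when both operands do, so the new
  // point takes the union of the condition bits of both operands.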
  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  // Followed by all in s0.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  // Map binary 0-y to unary -y.
  if (kind == Kind::kSubF)
    s1 = mapSet(Kind::kNegF, s1);
  else if (kind == Kind::kSubI)
    s1 = mapSet(Kind::kNegI, s1);
  // Followed by all in s1.
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

unsigned Merger::mapSet(Kind kind, unsigned s0) {
  assert(Kind::kAbsF <= kind && kind <= Kind::kNegI);
  unsigned s = addSet();
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, latPoints[p].exp);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

unsigned Merger::optimizeSet(unsigned s0) {
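  // The first lattice point of s0 is always kept. Any other point survives
  // only if it is not a mere copy of the output tensor and is not already
  // covered, up to dense dimensions, by a previously accepted point.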
  unsigned s = addSet();
  assert(latSets[s0].size() != 0);
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == Kind::kTensor &&
          tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules: a dense condition is always true within
  // its loop, so at most the first dense bit needs to remain to drive the
  // iteration, and for a singleton with at least one sparse condition, even
  // that dense bit can be dropped.
  llvm::BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, Dim::kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

bool Merger::latGT(unsigned i, unsigned j) const {
  const llvm::BitVector &bitsi = latPoints[i].bits;
  const llvm::BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
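  // Point i is greater than point j iff i's condition set is a strict
  // superset of j's, e.g. a point that co-iterates tensors 0 and 1 is
  // greater than one that iterates tensor 0 alone.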
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
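  // The symmetric difference of the two condition sets yields exactly the
  // conditions on which the points disagree; the difference is harmless
  // whenever it contains no sparse dimension.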
  llvm::BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, Dim::kSparse);
}

bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

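// Returns true if the expression is guaranteed to evaluate to zero whenever
// the element of tensor t is zero, viz. tensor t occurs as a conjunctive
// factor of e: unary operations as well as divisions and shifts by an
// invariant (see buildTensorExp) preserve this property through their left
// operand.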
bool Merger::isConjunction(unsigned t, unsigned e) const {
  switch (tensorExps[e].kind) {
  case Kind::kTensor:
    return tensorExps[e].tensor == t;
  case Kind::kAbsF:
  case Kind::kCeilF:
  case Kind::kFloorF:
  case Kind::kNegF:
  case Kind::kNegI:
  case Kind::kDivF: // note: x / c only
  case Kind::kDivS:
  case Kind::kDivU:
  case Kind::kShrS: // note: x >> inv only
  case Kind::kShrU:
  case Kind::kShlI:
    return isConjunction(t, tensorExps[e].children.e0);
  case Kind::kMulF:
  case Kind::kMulI:
  case Kind::kAndI:
    return isConjunction(t, tensorExps[e].children.e0) ||
           isConjunction(t, tensorExps[e].children.e1);
  default:
    return false;
  }
}

#ifndef NDEBUG

//
// Print methods (for debugging).
//

// One symbol per tensor expression Kind, in declaration order (the three
// division kinds all map to "/").
static const char *kOpSymbols[] = {
    "",  "",  "abs", "ceil", "floor", "-", "-",   "*",   "*",  "/", "/",
    "/", "+", "+",   "-",    "-",     "&", "|",   "^",   "a>>", ">>", "<<"};

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case Kind::kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case Kind::kAbsF:
  case Kind::kCeilF:
  case Kind::kFloorF:
  case Kind::kNegF:
  case Kind::kNegI:
    llvm::dbgs() << kOpSymbols[tensorExps[e].kind] << " ";
    dumpExp(tensorExps[e].children.e0);
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kOpSymbols[tensorExps[e].kind] << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const llvm::BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case Dim::kSparse:
        llvm::dbgs() << "S";
        break;
      case Dim::kDense:
        llvm::dbgs() << "D";
        break;
      case Dim::kSingle:
        llvm::dbgs() << "T";
        break;
      case Dim::kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//
// Builder methods.
//

unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case Kind::kTensor:
  case Kind::kInvariant: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = addSet();
    unsigned t = kind == Kind::kTensor ? tensorExps[e].tensor : syntheticTensor;
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case Kind::kAbsF:
  case Kind::kCeilF:
  case Kind::kFloorF:
  case Kind::kNegF:
  case Kind::kNegI:
    // A zero preserving operation (viz. f(0) = 0, [Bik96,Ch5]) maps the
    // lattice set of the operand through the operator into a new set.
    //
    //  -y|!y | y |
    //  --+---+---+
    //    | 0 |-y |
    return mapSet(kind, buildLattices(tensorExps[e].children.e0, i));
  case Kind::kMulF:
  case Kind::kMulI:
  case Kind::kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case Kind::kDivF:
  case Kind::kDivS:
  case Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rule applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case Kind::kAddF:
  case Kind::kAddI:
  case Kind::kSubF:
  case Kind::kSubI:
  case Kind::kOrI:
  case Kind::kXorI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case Kind::kShrS:
  case Kind::kShrU:
  case Kind::kShlI:
    // A shift operation by an invariant amount (viz. tensor expressions
    // can only occur at the left-hand-side of the operator) can be handled
    // with the conjunction rule.
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  }
  llvm_unreachable("unexpected expression kind");
}

Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

bool Merger::maybeZero(unsigned e) const {
  if (tensorExps[e].kind == Kind::kInvariant) {
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
      return c.getValue() == 0;
    if (auto c = tensorExps[e].val.getDefiningOp<ConstantFloatOp>())
      return c.getValue().isZero();
  }
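  // Conservatively assume that anything else may be zero.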
  return true;
}

bool Merger::isInvariant(unsigned e) const {
  return tensorExps[e].kind == Kind::kInvariant;
}

Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(Kind::kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(Kind::kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(Kind::kInvariant, v);
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e = x.getValue();
      if (isa<AbsFOp>(def))
        return addExp(Kind::kAbsF, e);
      if (isa<CeilFOp>(def))
        return addExp(Kind::kCeilF, e);
      if (isa<FloorFOp>(def))
        return addExp(Kind::kFloorF, e);
      if (isa<NegFOp>(def))
        return addExp(Kind::kNegF, e);
      // TODO: no negi in std?
    }
  }
  // Construct binary operations if subexpressions can be built.
  // TODO: see buildLattices() for an explanation of rejecting certain divisions
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return addExp(Kind::kMulI, e0, e1);
      if (isa<DivFOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivF, e0, e1);
      if (isa<SignedDivIOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivS, e0, e1);
      if (isa<UnsignedDivIOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivU, e0, e1);
      if (isa<AddFOp>(def))
        return addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return addExp(Kind::kAddI, e0, e1);
      if (isa<SubFOp>(def))
        return addExp(Kind::kSubF, e0, e1);
      if (isa<SubIOp>(def))
        return addExp(Kind::kSubI, e0, e1);
      if (isa<AndOp>(def))
        return addExp(Kind::kAndI, e0, e1);
      if (isa<OrOp>(def))
        return addExp(Kind::kOrI, e0, e1);
      if (isa<XOrOp>(def))
        return addExp(Kind::kXorI, e0, e1);
      if (isa<SignedShiftRightOp>(def) && isInvariant(e1))
        return addExp(Kind::kShrS, e0, e1);
      if (isa<UnsignedShiftRightOp>(def) && isInvariant(e1))
        return addExp(Kind::kShrU, e0, e1);
      if (isa<ShiftLeftOp>(def) && isInvariant(e1))
        return addExp(Kind::kShlI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}

Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("unexpected non-op");
  case Kind::kAbsF:
    return rewriter.create<AbsFOp>(loc, v0);
  case Kind::kCeilF:
    return rewriter.create<CeilFOp>(loc, v0);
  case Kind::kFloorF:
    return rewriter.create<FloorFOp>(loc, v0);
  case Kind::kNegF:
    return rewriter.create<NegFOp>(loc, v0);
  case Kind::kNegI: // no negi in std, so generate 0 - v0
    return rewriter.create<SubIOp>(
        loc,
        rewriter.create<ConstantOp>(loc, v0.getType(),
                                    rewriter.getZeroAttr(v0.getType())),
        v0);
  case Kind::kMulF:
    return rewriter.create<MulFOp>(loc, v0, v1);
  case Kind::kMulI:
    return rewriter.create<MulIOp>(loc, v0, v1);
  case Kind::kDivF:
    return rewriter.create<DivFOp>(loc, v0, v1);
  case Kind::kDivS:
    return rewriter.create<SignedDivIOp>(loc, v0, v1);
  case Kind::kDivU:
    return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
  case Kind::kAddF:
    return rewriter.create<AddFOp>(loc, v0, v1);
  case Kind::kAddI:
    return rewriter.create<AddIOp>(loc, v0, v1);
  case Kind::kSubF:
    return rewriter.create<SubFOp>(loc, v0, v1);
  case Kind::kSubI:
    return rewriter.create<SubIOp>(loc, v0, v1);
  case Kind::kAndI:
    return rewriter.create<AndOp>(loc, v0, v1);
  case Kind::kOrI:
    return rewriter.create<OrOp>(loc, v0, v1);
  case Kind::kXorI:
    return rewriter.create<XOrOp>(loc, v0, v1);
  case Kind::kShrS:
    return rewriter.create<SignedShiftRightOp>(loc, v0, v1);
  case Kind::kShrU:
    return rewriter.create<UnsignedShiftRightOp>(loc, v0, v1);
  case Kind::kShlI:
    return rewriter.create<ShiftLeftOp>(loc, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}

} // namespace sparse_tensor
} // namespace mlir
554