1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 
11 #include "mlir/IR/Operation.h"
12 #include "llvm/Support/Debug.h"
13 
14 namespace mlir {
15 namespace sparse_tensor {
16 
17 //
18 // Constructors.
19 //
20 
21 TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
22     : kind(k), val(v) {
23   switch (kind) {
24   case Kind::kTensor:
25     assert(x != -1u && y == -1u && !v);
26     tensor = x;
27     break;
28   case Kind::kInvariant:
29     assert(x == -1u && y == -1u && v);
30     break;
31   case Kind::kZero:
32     assert(x == -1u && y == -1u && !v);
33     break;
34   default:
35     assert(x != -1u && y != -1u && !v);
36     children.e0 = x;
37     children.e1 = y;
38     break;
39   }
40 }
41 
42 LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
43     : bits(n, false), simple(), exp(e) {
44   bits.set(b);
45 }
46 
/// Constructs a lattice point from an existing bit set (the tensor/loop
/// iterations involved) and the index of its tensor expression.
LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}
49 
50 //
51 // Lattice methods.
52 //
53 
54 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
55   unsigned e = tensorExps.size();
56   tensorExps.push_back(TensorExp(k, e0, e1, v));
57   return e;
58 }
59 
60 unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
61   assert(t < numTensors && i < numLoops);
62   unsigned p = latPoints.size();
63   latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
64   return p;
65 }
66 
67 unsigned Merger::addSet() {
68   unsigned s = latSets.size();
69   latSets.emplace_back(SmallVector<unsigned, 16>());
70   return s;
71 }
72 
73 unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
74   unsigned p = latPoints.size();
75   llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
76   nb |= latPoints[p1].bits;
77   unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
78   latPoints.push_back(LatPoint(nb, e));
79   return p;
80 }
81 
82 unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
83   unsigned s = addSet();
84   for (unsigned p0 : latSets[s0])
85     for (unsigned p1 : latSets[s1])
86       latSets[s].push_back(conjLatPoint(kind, p0, p1));
87   return s;
88 }
89 
90 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
91   unsigned s = takeConj(kind, s0, s1);
92   // Followed by all in s0 and s1.
93   for (unsigned p : latSets[s0])
94     latSets[s].push_back(p);
95   if (Kind::kSubF <= kind && kind <= Kind::kSubI)
96     s1 = mapZero(kind, s1);
97   for (unsigned p : latSets[s1])
98     latSets[s].push_back(p);
99   return s;
100 }
101 
102 unsigned Merger::mapZero(Kind kind, unsigned s0) {
103   assert(Kind::kSubF <= kind && kind <= Kind::kSubI);
104   unsigned s = addSet();
105   unsigned z = addExp(Kind::kZero);
106   for (unsigned p : latSets[s0]) {
107     unsigned e = addExp(kind, z, latPoints[p].exp);
108     latPoints.push_back(LatPoint(latPoints[p].bits, e));
109     latSets[s].push_back(latPoints.size() - 1);
110   }
111   return s;
112 }
113 
/// Maps set s0 to a new set in which redundant lattice points are removed:
/// a point is dropped when it is a plain copy of the output tensor, or when
/// an already-kept point covers the same condition up to dense dimensions.
/// Finally, the simplified condition of every surviving point is computed.
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(latSets[s0].size() != 0);
  // The first point of s0 is always kept (the `p0 != p1` guard below
  // skips all pruning tests for it).
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == Kind::kTensor &&
          tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      // Every point that is kept (other than the first) must be strictly
      // contained in the first point's iteration space.
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  // Precompute the simplified conditions for all surviving points.
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}
143 
/// Computes a simplified condition (bit set) for lattice point p0 within
/// set s0: non-sparse bits that are provably redundant for driving the
/// iteration are cleared. At most one non-sparse bit survives, and even
/// that one is cleared when the point is a singleton that already carries
/// a sparse dimension.
llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules. The `reset` flag starts out true only
  // for a singleton with at least one sparse dimension; afterwards it is
  // set by the first non-sparse bit encountered, so that every subsequent
  // non-sparse bit is cleared.
  llvm::BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, Dim::kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}
166 
167 bool Merger::latGT(unsigned i, unsigned j) const {
168   const llvm::BitVector &bitsi = latPoints[i].bits;
169   const llvm::BitVector &bitsj = latPoints[j].bits;
170   assert(bitsi.size() == bitsj.size());
171   if (bitsi.count() > bitsj.count()) {
172     for (unsigned b = 0, be = bitsj.size(); b < be; b++)
173       if (bitsj[b] && !bitsi[b])
174         return false;
175     return true;
176   }
177   return false;
178 }
179 
180 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
181   llvm::BitVector tmp = latPoints[j].bits;
182   tmp ^= latPoints[i].bits;
183   return !hasAnyDimOf(tmp, Dim::kSparse);
184 }
185 
186 bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
187   for (unsigned b = 0, be = bits.size(); b < be; b++)
188     if (bits[b] && isDim(b, d))
189       return true;
190   return false;
191 }
192 
193 bool Merger::isConjunction(unsigned t, unsigned e) const {
194   switch (tensorExps[e].kind) {
195   case Kind::kTensor:
196     return tensorExps[e].tensor == t;
197   case Kind::kMulF:
198   case Kind::kMulI:
199   case Kind::kAndI:
200   case Kind::kDivF: // note: x / c only
201   case Kind::kDivS:
202   case Kind::kDivU:
203     return isConjunction(t, tensorExps[e].children.e0) ||
204            isConjunction(t, tensorExps[e].children.e1);
205   default:
206     return false;
207   }
208 }
209 
210 #ifndef NDEBUG
211 
212 //
213 // Print methods (for debugging).
214 //
215 
216 static char kindToOpSymbol(Kind kind) {
217   switch (kind) {
218   case Kind::kMulF:
219   case Kind::kMulI:
220     return '*';
221   case Kind::kDivF:
222   case Kind::kDivS:
223   case Kind::kDivU:
224     return '/';
225   case Kind::kAddF:
226   case Kind::kAddI:
227     return '+';
228   case Kind::kSubF:
229   case Kind::kSubI:
230     return '-';
231   case Kind::kAndI:
232     return '&';
233   case Kind::kOrI:
234     return '|';
235   default:
236     break;
237   }
238   llvm_unreachable("unexpected kind");
239 }
240 
241 void Merger::dumpExp(unsigned e) const {
242   switch (tensorExps[e].kind) {
243   case Kind::kTensor:
244     if (tensorExps[e].tensor == syntheticTensor)
245       llvm::dbgs() << "synthetic_";
246     else if (tensorExps[e].tensor == outTensor)
247       llvm::dbgs() << "output_";
248     llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
249     break;
250   case Kind::kInvariant:
251     llvm::dbgs() << "invariant";
252     break;
253   case Kind::kZero:
254     llvm::dbgs() << "zero";
255     break;
256   default:
257     llvm::dbgs() << "(";
258     dumpExp(tensorExps[e].children.e0);
259     llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
260     dumpExp(tensorExps[e].children.e1);
261     llvm::dbgs() << ")";
262   }
263 }
264 
265 void Merger::dumpLat(unsigned p) const {
266   llvm::dbgs() << "lat(";
267   dumpBits(latPoints[p].bits);
268   llvm::dbgs() << " :";
269   dumpBits(latPoints[p].simple);
270   llvm::dbgs() << " : ";
271   dumpExp(latPoints[p].exp);
272   llvm::dbgs() << " )\n";
273 }
274 
/// Prints lattice set s (prefixed by its point count) to debug output,
/// one lattice point per line.
void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}
283 
284 void Merger::dumpBits(const llvm::BitVector &bits) const {
285   for (unsigned b = 0, be = bits.size(); b < be; b++) {
286     if (bits[b]) {
287       unsigned t = tensor(b);
288       unsigned i = index(b);
289       llvm::dbgs() << " i_" << t << "_" << i << "_";
290       switch (dims[t][i]) {
291       case Dim::kSparse:
292         llvm::dbgs() << "S";
293         break;
294       case Dim::kDense:
295         llvm::dbgs() << "D";
296         break;
297       case Dim::kSingle:
298         llvm::dbgs() << "T";
299         break;
300       case Dim::kUndef:
301         llvm::dbgs() << "U";
302         break;
303       }
304     }
305   }
306 }
307 
308 #endif // NDEBUG
309 
310 //
311 // Builder methods.
312 //
313 
/// Recursively builds the iteration lattices for tensor expression e with
/// respect to loop index i: leaves yield a singleton set, multiplicative
/// operations take the conjunction of their children's sets, and additive
/// operations take the disjunction (with 0-y handled specially).
unsigned Merger::buildLattices(unsigned e, unsigned i) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
  case Kind::kZero: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = addSet();
    unsigned t = kind == Kind::kTensor ? tensorExps[e].tensor : syntheticTensor;
    latSets[s].push_back(addLat(t, i, e));
    return s;
  }
  case Kind::kMulF:
  case Kind::kMulI:
  case Kind::kAndI:
    // A multiplicative operation only needs to be performed
    // for the conjunction of sparse iteration spaces.
    //
    //  x*y|!y | y |
    //  ---+---+---+
    //  !x | 0 | 0 |
    //   x | 0 |x*y|
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case Kind::kDivF:
  case Kind::kDivS:
  case Kind::kDivU:
    // A division is tricky, since 0/0, 0/c, c/0 all have
    // specific outcomes for floating-point and integers.
    // Thus, we need to traverse the full iteration space.
    //
    //  x/y|!y | y |
    //  ---+---+---+
    //  !x |0/0|0/y|   FP: 0/0=NaN,c/0=Inf,0/c=0 with c true nonzero
    //   x |x/0|x/y|  INT: x/0=exception for any x
    //
    // TODO: for now we "fixed" this by only accepting x/c cases
    //       during expression building, so that the conjunction
    //       rules applies (viz. x/c = x*(1/c) as far as lattice
    //       construction is concerned).
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  case Kind::kSubF:
  case Kind::kSubI:
    // Special case: 0-y is -y.
    if (tensorExps[tensorExps[e].children.e0].kind == Kind::kZero)
      return mapZero(kind, // maps to 0-y with just y's lattices
                     buildLattices(tensorExps[e].children.e1, i));
    LLVM_FALLTHROUGH;
  case Kind::kAddF:
  case Kind::kAddI:
  case Kind::kOrI:
    // An additive operation needs to be performed
    // for the disjunction of sparse iteration spaces.
    //
    //  x+y|!y | y |    x-y|!y | y |
    //  ---+---+---+    ---+---+---+
    //  !x | 0 | y |    !x | 0 |-y |
    //   x | x |x+y|     x | x |x-y|
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, i),
                    buildLattices(tensorExps[e].children.e1, i));
  }
  // All enum values are handled above; reaching here is a logic error.
  llvm_unreachable("unexpected expression kind");
}
383 
384 Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
385   Operation *yield = op.region().front().getTerminator();
386   return buildTensorExp(op, yield->getOperand(0));
387 }
388 
389 bool Merger::maybeZero(unsigned e) {
390   if (tensorExps[e].kind == Kind::kInvariant) {
391     if (auto c = tensorExps[e].val.getDefiningOp<ConstantIntOp>())
392       return c.getValue() == 0;
393     if (auto c = tensorExps[e].val.getDefiningOp<ConstantFloatOp>())
394       return c.getValue().isZero();
395   }
396   return true;
397 }
398 
/// Recursively builds a tensor expression for value v within the given
/// linalg.generic op. Block arguments become tensor leaves (or invariants
/// when marked scalar), values defined outside the region become
/// invariants, and supported unary/binary standard ops recurse on their
/// operands. Returns None when the expression cannot be represented.
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value v) {
  if (auto arg = v.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(Kind::kTensor, argN);
      v = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(Kind::kInvariant, v);
  }
  // Something defined outside is invariant.
  Operation *def = v.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(Kind::kInvariant, v);
  // Construct unary operations if subexpression can be built.
  // Negation is encoded as subtraction from an explicit zero.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e0 = addExp(Kind::kZero);
      unsigned e1 = x.getValue();
      if (isa<NegFOp>(def))
        return addExp(Kind::kSubF, e0, e1);
      // TODO: no negi in std?
    }
  }
  // Construct binary operations if subexpressions can be built.
  // TODO: see buildLattices() for an explanation of rejecting certain divisions
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return addExp(Kind::kMulI, e0, e1);
      if (isa<DivFOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivF, e0, e1);
      if (isa<SignedDivIOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivS, e0, e1);
      if (isa<UnsignedDivIOp>(def) && !maybeZero(e1))
        return addExp(Kind::kDivU, e0, e1);
      if (isa<AddFOp>(def))
        return addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return addExp(Kind::kAddI, e0, e1);
      if (isa<SubFOp>(def))
        return addExp(Kind::kSubF, e0, e1);
      if (isa<SubIOp>(def))
        return addExp(Kind::kSubI, e0, e1);
      if (isa<AndOp>(def))
        return addExp(Kind::kAndI, e0, e1);
      if (isa<OrOp>(def))
        return addExp(Kind::kOrI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}
465 
/// Generates the scalar MLIR operation implementing tensor expression e,
/// using previously computed operand values v0 and v1. Leaf expressions
/// (tensor/invariant/zero) are materialized elsewhere and must not reach
/// this method.
Value Merger::buildExp(PatternRewriter &rewriter, Location loc, unsigned e,
                       Value v0, Value v1) {
  switch (tensorExps[e].kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
  case Kind::kZero:
    llvm_unreachable("unexpected non-op");
  case Kind::kMulF:
    return rewriter.create<MulFOp>(loc, v0, v1);
  case Kind::kMulI:
    return rewriter.create<MulIOp>(loc, v0, v1);
  case Kind::kDivF:
    return rewriter.create<DivFOp>(loc, v0, v1);
  case Kind::kDivS:
    return rewriter.create<SignedDivIOp>(loc, v0, v1);
  case Kind::kDivU:
    return rewriter.create<UnsignedDivIOp>(loc, v0, v1);
  case Kind::kAddF:
    return rewriter.create<AddFOp>(loc, v0, v1);
  case Kind::kAddI:
    return rewriter.create<AddIOp>(loc, v0, v1);
  case Kind::kSubF:
    return rewriter.create<SubFOp>(loc, v0, v1);
  case Kind::kSubI:
    return rewriter.create<SubIOp>(loc, v0, v1);
  case Kind::kAndI:
    return rewriter.create<AndOp>(loc, v0, v1);
  case Kind::kOrI:
    return rewriter.create<OrOp>(loc, v0, v1);
  }
  llvm_unreachable("unexpected expression kind in build");
}
498 
499 } // namespace sparse_tensor
500 } // namespace mlir
501