//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//
// Constructors.
//

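// A tensor expression leaf holds either a tensor reference (kTensor), an
// invariant value (kInvariant), or the constant zero used to encode a
// negation as 0-y (kZero); all other kinds are binary operations over two
// previously added expression indices.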
TensorExp::TensorExp(Kind k, unsigned x, unsigned y, Value v)
    : kind(k), val(v) {
  switch (kind) {
  case Kind::kTensor:
    assert(x != -1u && y == -1u && !v);
    tensor = x;
    break;
  case Kind::kInvariant:
    assert(x == -1u && y == -1u && v);
    break;
  case Kind::kZero:
    assert(x == -1u && y == -1u && !v);
    break;
  default:
    assert(x != -1u && y != -1u && !v);
    children.e0 = x;
    children.e1 = y;
    break;
  }
}

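// A lattice point is a (bits, exp) pair: bit (numTensors * i + t) is set
// when tensor t contributes to the iteration over loop index i of the
// expression stored at index exp.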
LatPoint::LatPoint(unsigned n, unsigned e, unsigned b)
    : bits(n, false), simple(), exp(e) {
  bits.set(b);
}

LatPoint::LatPoint(const llvm::BitVector &b, unsigned e)
    : bits(b), simple(), exp(e) {}

//
// Lattice methods.
//

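// The methods below build the expression tree, lattice points, and lattice
// sets in flat tables; each returns the index of the newly added entry, and
// all cross references between these entities are by such indices.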
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

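// A conjunction of two lattice points iterates over the intersection of
// their iteration spaces: the new point takes the union of the two bit
// vectors and pairs it with the expression that combines both operands.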
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

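// A conjunction of two lattice sets takes the cross product of all point
// pairs, since e.g. a multiplication x * y only contributes nonzeros where
// both operands are present.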
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

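// A disjunction takes all pairwise conjunctions, followed by the points of
// each operand alone, since e.g. an addition x + y also contributes where
// only one operand is present; for x(i) = a(i) + b(i) this yields the set
// { a /\ b, a, b }. For subtraction, the points taken from s1 alone are
// first rewritten into 0-y form, since x - y must still negate y where x
// is absent.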
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  // Followed by all in s0 and s1.
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  if (Kind::kSubF <= kind && kind <= Kind::kSubI)
    s1 = mapZero(kind, s1);
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

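// Maps every lattice point of the set onto a new point with the same
// iteration bits but a 0-y expression, used for the rhs-only part of a
// subtraction.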
unsigned Merger::mapZero(Kind kind, unsigned s0) {
  assert(Kind::kSubF <= kind && kind <= Kind::kSubI);
  unsigned s = addSet();
  unsigned z = addExp(Kind::kZero);
  for (unsigned p : latSets[s0]) {
    unsigned e = addExp(kind, z, latPoints[p].exp);
    latPoints.push_back(LatPoint(latPoints[p].bits, e));
    latSets[s].push_back(latPoints.size() - 1);
  }
  return s;
}

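// Optimizes a lattice set by removing subsumed points: a point that merely
// copies the output tensor is dropped, as is any point whose sparse
// iteration condition is already covered by an accepted point (i.e., the
// two differ in dense dimensions only). The surviving points then receive
// a simplified condition through simplifyCond().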
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(latSets[s0].size() != 0);
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == Kind::kTensor &&
          tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

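// Simplifies the loop condition of lattice point p0 within set s0: all but
// the first non-sparse test are always redundant, and when p0 is a
// *singleton* that iterates over at least one sparse dimension, the first
// non-sparse test can be dropped as well.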
llvm::BitVector Merger::simplifyCond(unsigned s0, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in the lattice; no other point is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s0]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  llvm::BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, Dim::kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

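// Returns true if lattice point i is strictly greater than lattice point j,
// i.e., the bits of i form a proper superset of the bits of j.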
bool Merger::latGT(unsigned i, unsigned j) const {
  const llvm::BitVector &bitsi = latPoints[i].bits;
  const llvm::BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

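// Returns true if lattice points i and j differ in dense dimensions only,
// so that they impose the same sparse iteration condition.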
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  llvm::BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, Dim::kSparse);
}

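// Returns true if any set bit corresponds to a dimension of the queried
// type d.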
bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

#ifndef NDEBUG

//
// Print methods (for debugging).
//

static char kindToOpSymbol(Kind kind) {
  switch (kind) {
  case Kind::kMulF:
  case Kind::kMulI:
    return '*';
  case Kind::kAddF:
  case Kind::kAddI:
    return '+';
  case Kind::kSubF:
  case Kind::kSubI:
    return '-';
  default:
    break;
  }
  llvm_unreachable("unexpected kind");
}

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case Kind::kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case Kind::kZero:
    llvm::dbgs() << "zero";
    break;
  default:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " " << kindToOpSymbol(tensorExps[e].kind) << " ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " : ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const llvm::BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case Dim::kSparse:
        llvm::dbgs() << "S";
        break;
      case Dim::kDense:
        llvm::dbgs() << "D";
        break;
      case Dim::kSingle:
        llvm::dbgs() << "T";
        break;
      case Dim::kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//
// Builder methods.
//

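// Recursively builds the iteration lattices of the (sub)expression rooted
// at e for loop index idx: leaves yield a singleton set, multiplications
// take a conjunction, additions take a disjunction, and a subtraction
// whose lhs is the constant zero (i.e., a negation) simply maps the
// lattices of its rhs into 0-y form.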
unsigned Merger::buildLattices(unsigned e, unsigned idx) {
  Kind kind = tensorExps[e].kind;
  switch (kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
  case Kind::kZero: {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = addSet();
    unsigned t = kind == Kind::kTensor ? tensorExps[e].tensor : syntheticTensor;
    latSets[s].push_back(addLat(t, idx, e));
    return s;
  }
  case Kind::kMulF:
  case Kind::kMulI:
    return takeConj(kind, // take binary conjunction
                    buildLattices(tensorExps[e].children.e0, idx),
                    buildLattices(tensorExps[e].children.e1, idx));
  case Kind::kSubF:
  case Kind::kSubI:
    if (tensorExps[tensorExps[e].children.e0].kind == Kind::kZero)
      return mapZero(kind, // maps to 0-y with just y's lattices
                     buildLattices(tensorExps[e].children.e1, idx));
    LLVM_FALLTHROUGH;
  case Kind::kAddF:
  case Kind::kAddI:
    return takeDisj(kind, // take binary disjunction
                    buildLattices(tensorExps[e].children.e0, idx),
                    buildLattices(tensorExps[e].children.e1, idx));
  }
  llvm_unreachable("unexpected expression kind");
}

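// Builds the tensor expression for the yielded value of the linalg.generic
// op, returning None when the computation cannot be admitted.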
Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

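// Recursively admits a tensor expression for a value in the linalg op
// body: block arguments become tensor leaves or invariants, values defined
// outside the body are invariant, and recognized unary/binary arithmetic
// operations recurse into their operands.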
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(Kind::kTensor, argN);
      val = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(Kind::kInvariant, val);
  }
  // Something defined outside is invariant.
  Operation *def = val.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(Kind::kInvariant, val);
  // Construct unary operations if subexpression can be built.
  if (def->getNumOperands() == 1) {
    auto x = buildTensorExp(op, def->getOperand(0));
    if (x.hasValue()) {
      unsigned e0 = addExp(Kind::kZero);
      unsigned e1 = x.getValue();
      if (isa<NegFOp>(def))
        return addExp(Kind::kSubF, e0, e1);
      // TODO: no negi in std?
    }
  }
  // Construct binary operations if subexpressions can be built.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return addExp(Kind::kMulI, e0, e1);
      if (isa<AddFOp>(def))
        return addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return addExp(Kind::kAddI, e0, e1);
      if (isa<SubFOp>(def))
        return addExp(Kind::kSubF, e0, e1);
      if (isa<SubIOp>(def))
        return addExp(Kind::kSubI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}

} // namespace sparse_tensor
} // namespace mlir