//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//
// Lattice methods.
//

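// Adds a tensor expression of the given kind to the merger and returns its
// index in the tensorExps array.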
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

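// Adds an iteration lattice point for tensor t at loop index i, covering
// expression e, and returns its index in the latPoints array.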
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

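// Adds a new, initially empty lattice set and returns its index.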
unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

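// Computes a single conjunction of two lattice points: the new point takes
// the union of both iteration bit sets and combines the two expressions
// with the given operation kind.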
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

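// Conjunctive merge of two lattice sets: the cross product of all pairs of
// points, as used for multiplication.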
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

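// Disjunctive merge of two lattice sets: all pairwise conjunctions plus the
// points of both operands themselves, as used for addition, where either
// side may contribute a nonzero on its own.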
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

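// Optimizes the iteration lattice: the first point is always kept, points
// that merely copy the output tensor are dropped, and points whose
// conditions differ only in dense dimensions from an already accepted point
// are redundant and removed. Each remaining point then gets a simplified
// condition for code generation.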
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(latSets[s0].size() != 0);
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == Kind::kTensor && tensorExps[e].e0 == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

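// Simplifies the condition bits of lattice point p0 within set s for code
// generation. Sparse bits are always kept; if the point is a singleton (no
// other point in the set lies below it) and involves a sparse dimension,
// all other bits are dropped, otherwise only the first non-sparse bit is
// kept.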
llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  llvm::BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, Dim::kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

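// Returns true if lattice point i is strictly greater than lattice point j,
// i.e., the iteration bits of j form a proper subset of the bits of i.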
bool Merger::latGT(unsigned i, unsigned j) const {
  const llvm::BitVector &bitsi = latPoints[i].bits;
  const llvm::BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

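// Returns true if lattice points i and j differ in non-sparse dimensions
// only, i.e., no sparse dimension occurs in the symmetric difference of
// their iteration bits.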
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  llvm::BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, Dim::kSparse);
}

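// Returns true if any set bit in the vector corresponds to a dimension of
// the queried type.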
bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

#ifndef NDEBUG

//
// Print methods (for debugging).
//

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case Kind::kTensor:
    if (tensorExps[e].e0 == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].e0 == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].e0;
    break;
  case Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case Kind::kMulF:
  case Kind::kMulI:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].e0);
    llvm::dbgs() << " * ";
    dumpExp(tensorExps[e].e1);
    llvm::dbgs() << ")";
    break;
  case Kind::kAddF:
  case Kind::kAddI:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].e0);
    llvm::dbgs() << " + ";
    dumpExp(tensorExps[e].e1);
    llvm::dbgs() << ")";
    break;
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " / ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const llvm::BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case Dim::kSparse:
        llvm::dbgs() << "S";
        break;
      case Dim::kDense:
        llvm::dbgs() << "D";
        break;
      case Dim::kSingle:
        llvm::dbgs() << "T";
        break;
      case Dim::kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//
// Builder methods.
//

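// Recursively builds the iteration lattices for tensor expression e at loop
// index idx: tensors and invariants yield a singleton lattice, while
// multiplications and additions merge the sub-lattices conjunctively and
// disjunctively, respectively.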
unsigned Merger::buildLattices(unsigned e, unsigned idx) {
  Kind kind = tensorExps[e].kind;
  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = addSet();
    unsigned t = kind == Kind::kTensor ? tensorExps[e].e0 : syntheticTensor;
    latSets[s].push_back(addLat(t, idx, e));
    return s;
  }
  unsigned s0 = buildLattices(tensorExps[e].e0, idx);
  unsigned s1 = buildLattices(tensorExps[e].e1, idx);
  switch (kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
  case Kind::kMulI:
    return takeConj(kind, s0, s1);
  case Kind::kAddF:
  case Kind::kAddI:
    return takeDisj(kind, s0, s1);
  }
  llvm_unreachable("unexpected expression kind");
}

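// Builds a tensor expression from the value yielded by the region of the
// linalg.generic op, returning None if the expression is not supported.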
Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

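// Recursively builds a tensor expression for the given value: non-scalar
// block arguments of the generic op become tensor leaves, values defined
// outside the region become invariants, and supported binary arithmetic
// operations become inner nodes.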
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(Kind::kTensor, argN);
      val = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(Kind::kInvariant, val);
  }
  // Something defined outside is invariant.
  Operation *def = val.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(Kind::kInvariant, val);
  // Construct binary operations if subexpressions could be built.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return addExp(Kind::kMulI, e0, e1);
      if (isa<AddFOp>(def))
        return addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return addExp(Kind::kAddI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}

} // namespace sparse_tensor
} // namespace mlir