//===- Merger.cpp - Implementation of iteration lattices ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SparseTensor/Utils/Merger.h"

#include "mlir/IR/Operation.h"
#include "llvm/Support/Debug.h"

namespace mlir {
namespace sparse_tensor {

//
// Lattice methods.
//

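// Adds a tensor expression of the given kind to the merger state and
// returns its index in the expression table.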
unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
  unsigned e = tensorExps.size();
  tensorExps.push_back(TensorExp(k, e0, e1, v));
  return e;
}

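// Adds an iteration lattice point for tensor t and loop index i, referring
// to tensor expression e, and returns its index in the lattice point table.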
unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
  assert(t < numTensors && i < numLoops);
  unsigned p = latPoints.size();
  latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
  return p;
}

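// Adds a new, initially empty, set of lattice points and returns its index.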
unsigned Merger::addSet() {
  unsigned s = latSets.size();
  latSets.emplace_back(SmallVector<unsigned, 16>());
  return s;
}

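// Computes a single conjunction of two lattice points by taking the union
// of their loop-index bits and constructing a new tensor (sub)expression of
// the given kind over their expressions. Returns the new point's index.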
unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
  unsigned p = latPoints.size();
  llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
  nb |= latPoints[p1].bits;
  unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
  latPoints.push_back(LatPoint(nb, e));
  return p;
}

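// Conjunctive merge of two lattice sets: the conjunction of the Cartesian
// product of the points in s0 and s1. Returns the index of the new set.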
unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = addSet();
  for (unsigned p0 : latSets[s0])
    for (unsigned p1 : latSets[s1])
      latSets[s].push_back(conjLatPoint(kind, p0, p1));
  return s;
}

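// Disjunctive merge of two lattice sets: all pairwise conjunctions followed
// by the points of s0 and s1 themselves, so that e.g. an addition over
// tensors A and B yields the points {A /\ B, A, B}. Returns the index of
// the new set.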
unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
  unsigned s = takeConj(kind, s0, s1);
  for (unsigned p : latSets[s0])
    latSets[s].push_back(p);
  for (unsigned p : latSets[s1])
    latSets[s].push_back(p);
  return s;
}

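// Optimizes the lattice points in set s0 by removing points whose
// conditions are already covered: straightforward copies of the output
// tensor and conjunctions that differ only in dense dimensions from a
// point already kept. Also computes the simplified conditions of the
// remaining points. Returns the index of the optimized set.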
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(latSets[s0].size() != 0);
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy?
      unsigned e = latPoints[p1].exp;
      if (tensorExps[e].kind == Kind::kTensor &&
          tensorExps[e].tensor == outTensor)
        continue;
      // Conjunction already covered?
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}

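// Simplifies the conditions of lattice point p0 within set s using two
// basic rules: (1) only the first non-sparse condition is kept, and
// (2) for a singleton point that contains a sparse condition, all
// non-sparse conditions are dropped entirely.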
llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  llvm::BitVector simple = latPoints[p0].bits;
  bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, Dim::kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true;
    }
  }
  return simple;
}

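// Returns true if lattice point i is strictly greater than lattice point j,
// i.e., the index bits of i form a proper superset of those of j.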
bool Merger::latGT(unsigned i, unsigned j) const {
  const llvm::BitVector &bitsi = latPoints[i].bits;
  const llvm::BitVector &bitsj = latPoints[j].bits;
  assert(bitsi.size() == bitsj.size());
  if (bitsi.count() > bitsj.count()) {
    for (unsigned b = 0, be = bitsj.size(); b < be; b++)
      if (bitsj[b] && !bitsi[b])
        return false;
    return true;
  }
  return false;
}

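// Returns true if lattice points i and j differ in non-sparse bits only,
// i.e., the symmetric difference of their bits has no sparse dimension.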
bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
  llvm::BitVector tmp = latPoints[j].bits;
  tmp ^= latPoints[i].bits;
  return !hasAnyDimOf(tmp, Dim::kSparse);
}

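// Returns true if any set bit in the vector corresponds to a dimension of
// the queried dimension type d.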
bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++)
    if (bits[b] && isDim(b, d))
      return true;
  return false;
}

#ifndef NDEBUG

//
// Print methods (for debugging).
//

void Merger::dumpExp(unsigned e) const {
  switch (tensorExps[e].kind) {
  case Kind::kTensor:
    if (tensorExps[e].tensor == syntheticTensor)
      llvm::dbgs() << "synthetic_";
    else if (tensorExps[e].tensor == outTensor)
      llvm::dbgs() << "output_";
    llvm::dbgs() << "tensor_" << tensorExps[e].tensor;
    break;
  case Kind::kInvariant:
    llvm::dbgs() << "invariant";
    break;
  case Kind::kMulF:
  case Kind::kMulI:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " * ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
    break;
  case Kind::kAddF:
  case Kind::kAddI:
    llvm::dbgs() << "(";
    dumpExp(tensorExps[e].children.e0);
    llvm::dbgs() << " + ";
    dumpExp(tensorExps[e].children.e1);
    llvm::dbgs() << ")";
    break;
  }
}

void Merger::dumpLat(unsigned p) const {
  llvm::dbgs() << "lat(";
  dumpBits(latPoints[p].bits);
  llvm::dbgs() << " :";
  dumpBits(latPoints[p].simple);
  llvm::dbgs() << " / ";
  dumpExp(latPoints[p].exp);
  llvm::dbgs() << " )\n";
}

void Merger::dumpSet(unsigned s) const {
  llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
  for (unsigned p : latSets[s]) {
    llvm::dbgs() << "  ";
    dumpLat(p);
  }
  llvm::dbgs() << "}\n";
}

void Merger::dumpBits(const llvm::BitVector &bits) const {
  for (unsigned b = 0, be = bits.size(); b < be; b++) {
    if (bits[b]) {
      unsigned t = tensor(b);
      unsigned i = index(b);
      llvm::dbgs() << " i_" << t << "_" << i << "_";
      switch (dims[t][i]) {
      case Dim::kSparse:
        llvm::dbgs() << "S";
        break;
      case Dim::kDense:
        llvm::dbgs() << "D";
        break;
      case Dim::kSingle:
        llvm::dbgs() << "T";
        break;
      case Dim::kUndef:
        llvm::dbgs() << "U";
        break;
      }
    }
  }
}

#endif // NDEBUG

//
// Builder methods.
//

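// Builds the iteration lattices for the given tensor (sub)expression and
// loop index in a bottom-up traversal: leaves (tensors and invariants)
// become singleton sets, multiplications merge conjunctively, and
// additions merge disjunctively.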
unsigned Merger::buildLattices(unsigned e, unsigned idx) {
  Kind kind = tensorExps[e].kind;
  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
    // Either the index is really used in the tensor expression, or it is
    // set to the undefined index in that dimension. An invariant expression
    // is set to a synthetic tensor with undefined indices only.
    unsigned s = addSet();
    unsigned t =
        kind == Kind::kTensor ? tensorExps[e].tensor : syntheticTensor;
    latSets[s].push_back(addLat(t, idx, e));
    return s;
  }
  unsigned s0 = buildLattices(tensorExps[e].children.e0, idx);
  unsigned s1 = buildLattices(tensorExps[e].children.e1, idx);
  switch (kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
  case Kind::kMulI:
    return takeConj(kind, s0, s1);
  case Kind::kAddF:
  case Kind::kAddI:
    return takeDisj(kind, s0, s1);
  }
  llvm_unreachable("unexpected expression kind");
}

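// Builds a tensor expression from the yielded value of the given linalg
// generic operation. Returns None if an expression cannot be built.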
Optional<unsigned> Merger::buildTensorExpFromLinalg(linalg::GenericOp op) {
  Operation *yield = op.region().front().getTerminator();
  return buildTensorExp(op, yield->getOperand(0));
}

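// Recursively builds a tensor expression for the given value inside the
// linalg generic operation: block arguments map to tensors or invariants,
// values defined outside the region are invariant, and recognized binary
// arithmetic operations become subexpressions. Returns None otherwise.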
Optional<unsigned> Merger::buildTensorExp(linalg::GenericOp op, Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    // Any argument of the generic op that is not marked as a scalar
    // argument is considered a tensor, indexed by the implicit loop
    // bounds. This includes rank-0 tensor arguments.
    if (arg.getOwner()->getParentOp() == op) {
      OpOperand *t = op.getInputAndOutputOperands()[argN];
      if (!op.isScalar(t))
        return addExp(Kind::kTensor, argN);
      val = t->get(); // get scalar value
    }
    // Any other argument (marked as scalar argument for the generic op
    // or belonging to an enveloping op) is considered invariant.
    return addExp(Kind::kInvariant, val);
  }
  // Something defined outside is invariant.
  Operation *def = val.getDefiningOp();
  if (def->getBlock() != &op.region().front())
    return addExp(Kind::kInvariant, val);
  // Construct binary operations if subexpressions could be built.
  if (def->getNumOperands() == 2) {
    auto x = buildTensorExp(op, def->getOperand(0));
    auto y = buildTensorExp(op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return addExp(Kind::kMulI, e0, e1);
      if (isa<AddFOp>(def))
        return addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return addExp(Kind::kAddI, e0, e1);
    }
  }
  // Cannot build.
  return None;
}

} // namespace sparse_tensor
} // namespace mlir