1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 
11 #include "mlir/IR/Operation.h"
12 #include "llvm/Support/Debug.h"
13 
14 namespace mlir {
15 namespace sparse_tensor {
16 
17 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
18   unsigned e = tensorExps.size();
19   tensorExps.push_back(TensorExp(k, e0, e1, v));
20   return e;
21 }
22 
23 unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
24   assert(t < numTensors && i < numLoops);
25   unsigned p = latPoints.size();
26   latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
27   return p;
28 }
29 
30 unsigned Merger::addSet() {
31   unsigned s = latSets.size();
32   latSets.emplace_back(SmallVector<unsigned, 16>());
33   return s;
34 }
35 
36 unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
37   unsigned p = latPoints.size();
38   llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
39   nb |= latPoints[p1].bits;
40   unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
41   latPoints.push_back(LatPoint(nb, e));
42   return p;
43 }
44 
45 unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
46   unsigned s = addSet();
47   for (unsigned p0 : latSets[s0])
48     for (unsigned p1 : latSets[s1])
49       latSets[s].push_back(conjLatPoint(kind, p0, p1));
50   return s;
51 }
52 
53 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
54   unsigned s = takeConj(kind, s0, s1);
55   for (unsigned p : latSets[s0])
56     latSets[s].push_back(p);
57   for (unsigned p : latSets[s1])
58     latSets[s].push_back(p);
59   return s;
60 }
61 
/// Builds a pruned copy of lattice set s0: points that are straightforward
/// copies of the output tensor are dropped, as are points whose condition
/// differs from an already-kept point only in non-sparse (dense) bits.
/// Each surviving point also gets a simplified condition (its `simple`
/// bits) via simplifyCond. Returns the index of the new set.
unsigned Merger::optimizeSet(unsigned s0) {
  unsigned s = addSet();
  assert(latSets[s0].size() != 0);
  // The first point of the set is always kept (the p0 != p1 guard below
  // skips all pruning checks for it).
  unsigned p0 = latSets[s0][0];
  for (unsigned p1 : latSets[s0]) {
    bool add = true;
    if (p0 != p1) {
      // Is this a straightforward copy? A point whose expression is just
      // the output tensor itself contributes nothing and is dropped.
      unsigned e = latPoints[p1].exp;
      if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
        continue;
      // Conjunction already covered? If a kept point's condition differs
      // from p1 only in dense bits, p1 is redundant.
      for (unsigned p2 : latSets[s]) {
        assert(!latGT(p1, p2)); // Lj => Li would be bad
        if (onlyDenseDiff(p2, p1)) {
          add = false;
          break;
        }
      }
      // Every kept point (other than the first) must be implied by the
      // first point's condition.
      assert(!add || latGT(p0, p1));
    }
    if (add)
      latSets[s].push_back(p1);
  }
  // Attach a simplified loop condition to each surviving point.
  for (unsigned p : latSets[s])
    latPoints[p].simple = simplifyCond(s, p);
  return s;
}
90 
/// Computes a simplified condition bit vector for lattice point p0 within
/// set s: at most one non-sparse bit is kept, and none at all when the
/// point is a singleton that also carries a sparse condition (the sparse
/// bits then suffice to drive the loop).
llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) {
  // First determine if this lattice point is a *singleton*, i.e.,
  // the last point in a lattice, no other is less than this one.
  bool isSingleton = true;
  for (unsigned p1 : latSets[s]) {
    if (p0 != p1 && latGT(p0, p1)) {
      isSingleton = false;
      break;
    }
  }
  // Now apply the two basic rules.
  llvm::BitVector simple = latPoints[p0].bits;
  // If the singleton already has a sparse condition, every non-sparse bit
  // is redundant from the start; otherwise the first non-sparse bit is
  // kept and only subsequent ones are cleared.
  bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
  for (unsigned b = 0, be = simple.size(); b < be; b++) {
    if (simple[b] && !isDim(b, Dim::kSparse)) {
      if (reset)
        simple.reset(b);
      reset = true; // all later non-sparse bits are dropped
    }
  }
  return simple;
}
113 
114 bool Merger::latGT(unsigned i, unsigned j) const {
115   const llvm::BitVector &bitsi = latPoints[i].bits;
116   const llvm::BitVector &bitsj = latPoints[j].bits;
117   assert(bitsi.size() == bitsj.size());
118   if (bitsi.count() > bitsj.count()) {
119     for (unsigned b = 0, be = bitsj.size(); b < be; b++)
120       if (bitsj[b] && !bitsi[b])
121         return false;
122     return true;
123   }
124   return false;
125 }
126 
127 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
128   llvm::BitVector tmp = latPoints[j].bits;
129   tmp ^= latPoints[i].bits;
130   return !hasAnyDimOf(tmp, Dim::kSparse);
131 }
132 
133 bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
134   for (unsigned b = 0, be = bits.size(); b < be; b++)
135     if (bits[b] && isDim(b, d))
136       return true;
137   return false;
138 }
139 
140 unsigned Merger::buildLattices(unsigned e, unsigned idx) {
141   Kind kind = exp(e).kind;
142   if (kind == Kind::kTensor || kind == Kind::kInvariant) {
143     // Either the index is really used in the tensor expression, or it is
144     // set to the undefined index in that dimension. An invariant expression
145     // is set to a synthetic tensor with undefined indices only.
146     unsigned s = addSet();
147     unsigned t = kind == Kind::kTensor ? exp(e).e0 : syntheticTensor;
148     set(s).push_back(addLat(t, idx, e));
149     return s;
150   }
151   unsigned s0 = buildLattices(exp(e).e0, idx);
152   unsigned s1 = buildLattices(exp(e).e1, idx);
153   switch (kind) {
154   case Kind::kTensor:
155   case Kind::kInvariant:
156     llvm_unreachable("handled above");
157   case Kind::kMulF:
158   case Kind::kMulI:
159     return takeConj(kind, s0, s1);
160   case Kind::kAddF:
161   case Kind::kAddI:
162     return takeDisj(kind, s0, s1);
163   }
164   llvm_unreachable("unexpected expression kind");
165 }
166 
167 #ifndef NDEBUG
168 
169 //
170 // Print methods (for debugging).
171 //
172 
173 void Merger::dumpExp(unsigned e) const {
174   switch (tensorExps[e].kind) {
175   case Kind::kTensor:
176     llvm::dbgs() << "tensor_" << tensorExps[e].e0;
177     break;
178   case Kind::kInvariant:
179     llvm::dbgs() << "invariant";
180     break;
181   default:
182   case Kind::kMulI:
183     llvm::dbgs() << "(";
184     dumpExp(tensorExps[e].e0);
185     llvm::dbgs() << " * ";
186     dumpExp(tensorExps[e].e1);
187     llvm::dbgs() << ")";
188     break;
189   case Kind::kAddF:
190   case Kind::kAddI:
191     llvm::dbgs() << "(";
192     dumpExp(tensorExps[e].e0);
193     llvm::dbgs() << " + ";
194     dumpExp(tensorExps[e].e1);
195     llvm::dbgs() << ")";
196     break;
197   }
198 }
199 
200 void Merger::dumpLat(unsigned p) const {
201   llvm::dbgs() << "lat(";
202   dumpBits(latPoints[p].bits);
203   llvm::dbgs() << " :";
204   dumpBits(latPoints[p].simple);
205   llvm::dbgs() << " / ";
206   dumpExp(latPoints[p].exp);
207   llvm::dbgs() << " )\n";
208 }
209 
210 void Merger::dumpSet(unsigned s) const {
211   llvm::dbgs() << "{ #" << latSets[s].size() << "\n";
212   for (unsigned p : latSets[s]) {
213     llvm::dbgs() << "  ";
214     dumpLat(p);
215   }
216   llvm::dbgs() << "}\n";
217 }
218 
219 void Merger::dumpBits(const llvm::BitVector &bits) const {
220   for (unsigned b = 0, be = bits.size(); b < be; b++) {
221     if (bits[b]) {
222       unsigned t = tensor(b);
223       unsigned i = index(b);
224       llvm::dbgs() << " i_" << t << "_" << i << "_";
225       switch (dims[t][i]) {
226       case Dim::kSparse:
227         llvm::dbgs() << "S";
228         break;
229       case Dim::kDense:
230         llvm::dbgs() << "D";
231         break;
232       case Dim::kSingle:
233         llvm::dbgs() << "T";
234         break;
235       case Dim::kUndef:
236         llvm::dbgs() << "U";
237         break;
238       }
239     }
240   }
241 }
242 
243 #endif // NDEBUG
244 
245 } // namespace sparse_tensor
246 } // namespace mlir
247