1 //===- Merger.cpp - Implementation of iteration lattices ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
10 
11 namespace mlir {
12 namespace sparse_tensor {
13 
14 unsigned Merger::addExp(Kind k, unsigned e0, unsigned e1, Value v) {
15   unsigned e = tensorExps.size();
16   tensorExps.push_back(TensorExp(k, e0, e1, v));
17   return e;
18 }
19 
20 unsigned Merger::addLat(unsigned t, unsigned i, unsigned e) {
21   assert(t < numTensors && i < numLoops);
22   unsigned p = latPoints.size();
23   latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
24   return p;
25 }
26 
27 unsigned Merger::addSet() {
28   unsigned s = latSets.size();
29   latSets.emplace_back(SmallVector<unsigned, 16>());
30   return s;
31 }
32 
33 unsigned Merger::conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
34   unsigned p = latPoints.size();
35   llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
36   nb |= latPoints[p1].bits;
37   unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
38   latPoints.push_back(LatPoint(nb, e));
39   return p;
40 }
41 
42 unsigned Merger::takeConj(Kind kind, unsigned s0, unsigned s1) {
43   unsigned s = addSet();
44   for (unsigned p0 : latSets[s0])
45     for (unsigned p1 : latSets[s1])
46       latSets[s].push_back(conjLatPoint(kind, p0, p1));
47   return s;
48 }
49 
50 unsigned Merger::takeDisj(Kind kind, unsigned s0, unsigned s1) {
51   unsigned s = takeConj(kind, s0, s1);
52   for (unsigned p : latSets[s0])
53     latSets[s].push_back(p);
54   for (unsigned p : latSets[s1])
55     latSets[s].push_back(p);
56   return s;
57 }
58 
59 unsigned Merger::optimizeSet(unsigned s0) {
60   unsigned s = addSet();
61   assert(latSets[s0].size() != 0);
62   unsigned p0 = latSets[s0][0];
63   for (unsigned p1 : latSets[s0]) {
64     bool add = true;
65     if (p0 != p1) {
66       // Is this a straightforward copy?
67       unsigned e = latPoints[p1].exp;
68       if (exp(e).kind == Kind::kTensor && exp(e).e0 == outTensor)
69         continue;
70       // Conjunction already covered?
71       for (unsigned p2 : latSets[s]) {
72         assert(!latGT(p1, p2)); // Lj => Li would be bad
73         if (onlyDenseDiff(p2, p1)) {
74           add = false;
75           break;
76         }
77       }
78       assert(!add || latGT(p0, p1));
79     }
80     if (add)
81       latSets[s].push_back(p1);
82   }
83   for (unsigned p : latSets[s])
84     latPoints[p].simple = simplifyCond(s, p);
85   return s;
86 }
87 
88 llvm::BitVector Merger::simplifyCond(unsigned s, unsigned p0) {
89   // First determine if this lattice point is a *singleton*, i.e.,
90   // the last point in a lattice, no other is less than this one.
91   bool isSingleton = true;
92   for (unsigned p1 : latSets[s]) {
93     if (p0 != p1 && latGT(p0, p1)) {
94       isSingleton = false;
95       break;
96     }
97   }
98   // Now apply the two basic rules.
99   llvm::BitVector simple = latPoints[p0].bits;
100   bool reset = isSingleton && hasAnyDimOf(simple, Dim::kSparse);
101   for (unsigned b = 0, be = simple.size(); b < be; b++) {
102     if (simple[b] && !isDim(b, Dim::kSparse)) {
103       if (reset)
104         simple.reset(b);
105       reset = true;
106     }
107   }
108   return simple;
109 }
110 
111 bool Merger::latGT(unsigned i, unsigned j) const {
112   const llvm::BitVector &bitsi = latPoints[i].bits;
113   const llvm::BitVector &bitsj = latPoints[j].bits;
114   assert(bitsi.size() == bitsj.size());
115   if (bitsi.count() > bitsj.count()) {
116     for (unsigned b = 0, be = bitsj.size(); b < be; b++)
117       if (bitsj[b] && !bitsi[b])
118         return false;
119     return true;
120   }
121   return false;
122 }
123 
124 bool Merger::onlyDenseDiff(unsigned i, unsigned j) {
125   llvm::BitVector tmp = latPoints[j].bits;
126   tmp ^= latPoints[i].bits;
127   return !hasAnyDimOf(tmp, Dim::kSparse);
128 }
129 
130 bool Merger::hasAnyDimOf(const llvm::BitVector &bits, Dim d) const {
131   for (unsigned b = 0, be = bits.size(); b < be; b++)
132     if (bits[b] && isDim(b, d))
133       return true;
134   return false;
135 }
136 
137 } // namespace sparse_tensor
138 } // namespace mlir
139