1 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
2 #include "gmock/gmock.h"
3 #include "gtest/gtest.h"
4 #include <memory>
5 
6 using namespace mlir::sparse_tensor;
7 
8 namespace {
9 
10 /// Simple recursive data structure used to match expressions in Mergers.
11 struct Pattern {
12   Kind kind;
13 
14   /// Expressions representing tensors simply have a tensor number.
15   unsigned tensorNum;
16 
17   /// Tensor operations point to their children.
18   std::shared_ptr<Pattern> e0;
19   std::shared_ptr<Pattern> e1;
20 
21   /// Constructors.
22   /// Rather than using these, please use the readable helper constructor
23   /// functions below to make tests more readable.
24   Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
25   Pattern(Kind kind, const std::shared_ptr<Pattern> &e0,
26           const std::shared_ptr<Pattern> &e1)
27       : kind(kind), e0(e0), e1(e1) {
28     assert(kind >= Kind::kMulF);
29     assert(e0 && e1);
30   }
31 };
32 
33 ///
34 /// Readable Pattern builder functions.
35 /// These should be preferred over the actual constructors.
36 ///
37 
38 static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
39   return std::make_shared<Pattern>(tensorNum);
40 }
41 
42 static std::shared_ptr<Pattern>
43 addfPattern(const std::shared_ptr<Pattern> &e0,
44             const std::shared_ptr<Pattern> &e1) {
45   return std::make_shared<Pattern>(Kind::kAddF, e0, e1);
46 }
47 
48 static std::shared_ptr<Pattern>
49 mulfPattern(const std::shared_ptr<Pattern> &e0,
50             const std::shared_ptr<Pattern> &e1) {
51   return std::make_shared<Pattern>(Kind::kMulF, e0, e1);
52 }
53 
54 class MergerTestBase : public ::testing::Test {
55 protected:
56   MergerTestBase(unsigned numTensors, unsigned numLoops)
57       : numTensors(numTensors), numLoops(numLoops),
58         merger(numTensors, numLoops) {}
59 
60   ///
61   /// Expression construction helpers.
62   ///
63 
64   unsigned tensor(unsigned tensor) {
65     return merger.addExp(Kind::kTensor, tensor);
66   }
67 
68   unsigned addf(unsigned e0, unsigned e1) {
69     return merger.addExp(Kind::kAddF, e0, e1);
70   }
71 
72   unsigned mulf(unsigned e0, unsigned e1) {
73     return merger.addExp(Kind::kMulF, e0, e1);
74   }
75 
76   ///
77   /// Comparison helpers.
78   ///
79 
80   /// For readability of tests.
81   unsigned lat(unsigned lat) { return lat; }
82 
83   /// Returns true if a lattice point with an expression matching the given
84   /// pattern and bits matching the given bits is present in lattice points
85   /// [p, p+n) of lattice set s. This is useful for testing partial ordering
86   /// constraints between lattice points. We generally know how contiguous
87   /// groups of lattice points should be ordered with respect to other groups,
88   /// but there is no required ordering within groups.
89   bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
90                            const std::shared_ptr<Pattern> &pattern,
91                            const llvm::BitVector &bits) {
92     for (unsigned i = p; i < p + n; ++i) {
93       if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
94           compareBits(s, i, bits))
95         return true;
96     }
97     return false;
98   }
99 
100   /// Wrapper over latPointWithinRange for readability of tests.
101   void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
102                                  const std::shared_ptr<Pattern> &pattern,
103                                  const llvm::BitVector &bits) {
104     EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits));
105   }
106 
107   /// Wrapper over expectLatPointWithinRange for a single lat point.
108   void expectLatPoint(unsigned s, unsigned p,
109                       const std::shared_ptr<Pattern> &pattern,
110                       const llvm::BitVector &bits) {
111     EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits));
112   }
113 
114   /// Converts a vector of (loop, tensor) pairs to a bitvector with the
115   /// corresponding bits set.
116   llvm::BitVector
117   loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) {
118     llvm::BitVector testBits = llvm::BitVector(numTensors + 1, false);
119     for (auto l : loops) {
120       auto loop = std::get<0>(l);
121       auto tensor = std::get<1>(l);
122       testBits.set(numTensors * loop + tensor);
123     }
124     return testBits;
125   }
126 
127   /// Returns true if the bits of lattice point p in set s match the given bits.
128   bool compareBits(unsigned s, unsigned p, const llvm::BitVector &bits) {
129     return merger.lat(merger.set(s)[p]).bits == bits;
130   }
131 
132   /// Check that there are n lattice points in set s.
133   void expectNumLatPoints(unsigned s, unsigned n) {
134     EXPECT_THAT(merger.set(s).size(), n);
135   }
136 
137   /// Compares expressions for equality. Equality is defined recursively as:
138   /// - Two expressions can only be equal if they have the same Kind.
139   /// - Two binary expressions are equal if they have the same Kind and their
140   ///     children are equal.
141   /// - Expressions with Kind invariant or tensor are equal if they have the
142   ///     same expression id.
143   bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
144     auto tensorExp = merger.exp(e);
145     if (tensorExp.kind != pattern->kind)
146       return false;
147     assert(tensorExp.kind != Kind::kInvariant &&
148            "Invariant comparison not yet supported");
149     switch (tensorExp.kind) {
150     case Kind::kTensor:
151       return tensorExp.tensor == pattern->tensorNum;
152     case Kind::kAbsF:
153     case Kind::kCeilF:
154     case Kind::kFloorF:
155     case Kind::kNegF:
156     case Kind::kNegI:
157       return compareExpression(tensorExp.children.e0, pattern->e0);
158     case Kind::kMulF:
159     case Kind::kMulI:
160     case Kind::kDivF:
161     case Kind::kDivS:
162     case Kind::kDivU:
163     case Kind::kAddF:
164     case Kind::kAddI:
165     case Kind::kSubF:
166     case Kind::kSubI:
167     case Kind::kAndI:
168     case Kind::kOrI:
169     case Kind::kXorI:
170       return compareExpression(tensorExp.children.e0, pattern->e0) &&
171              compareExpression(tensorExp.children.e1, pattern->e1);
172     default:
173       llvm_unreachable("Unhandled Kind");
174     }
175   }
176 
177   unsigned numTensors;
178   unsigned numLoops;
179   Merger merger;
180 };
181 
182 class MergerTest3T1L : public MergerTestBase {
183 protected:
184   // Our three tensors (two inputs, one output).
185   const unsigned t0 = 0, t1 = 1, t2 = 2;
186 
187   // Our single loop.
188   const unsigned l0 = 0;
189 
190   MergerTest3T1L() : MergerTestBase(3, 1) {
191     // Tensor 0: sparse input vector.
192     merger.addExp(Kind::kTensor, t0, -1u);
193     merger.setDim(t0, l0, Dim::kSparse);
194 
195     // Tensor 1: sparse input vector.
196     merger.addExp(Kind::kTensor, t1, -1u);
197     merger.setDim(t1, l0, Dim::kSparse);
198 
199     // Tensor 2: dense output vector.
200     merger.addExp(Kind::kTensor, t2, -1u);
201     merger.setDim(t2, l0, Dim::kDense);
202   }
203 };
204 
205 } // namespace
206 
207 /// Vector addition of 2 vectors, i.e.:
208 ///   a(i) = b(i) + c(i)
209 /// which should form the 3 lattice points
210 /// {
211 ///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
212 ///   lat( i_00 / tensor_0 )
213 ///   lat( i_01 / tensor_1 )
214 /// }
/// and after optimization, the same 3 lattice points remain, since both
/// input vectors are sparse and none of the points can be dropped.
220 TEST_F(MergerTest3T1L, VectorAdd2) {
221   // Construct expression.
222   auto e = addf(tensor(t0), tensor(t1));
223 
224   // Build lattices and check.
225   auto s = merger.buildLattices(e, l0);
226   expectNumLatPoints(s, 3);
227   expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
228                  loopsToBits({{l0, t0}, {l0, t1}}));
229   expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
230                             loopsToBits({{l0, t0}}));
231   expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
232                             loopsToBits({{l0, t1}}));
233 
234   // Optimize lattices and check.
235   s = merger.optimizeSet(s);
236   expectNumLatPoints(s, 3);
237   expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
238                  loopsToBits({{l0, t0}, {l0, t1}}));
239   expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
240                             loopsToBits({{l0, t0}}));
241   expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
242                             loopsToBits({{l0, t1}}));
243 }
244 
245 /// Vector multiplication of 2 vectors, i.e.:
246 ///   a(i) = b(i) * c(i)
247 /// which should form the single lattice point
248 /// {
249 ///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
250 /// }
251 TEST_F(MergerTest3T1L, VectorMul2) {
252   // Construct expression.
253   auto e = mulf(t0, t1);
254 
255   // Build lattices and check.
256   auto s = merger.buildLattices(e, l0);
257   expectNumLatPoints(s, 1);
258   expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
259                  loopsToBits({{l0, t0}, {l0, t1}}));
260 
261   // Optimize lattices and check.
262   s = merger.optimizeSet(s);
263   expectNumLatPoints(s, 1);
264   expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
265                  loopsToBits({{l0, t0}, {l0, t1}}));
266 }
267