1 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
2 #include "gmock/gmock.h"
3 #include "gtest/gtest.h"
4 #include <memory>
5 
6 using namespace mlir;
7 using namespace mlir::sparse_tensor;
8 
9 namespace {
10 
11 /// Simple recursive data structure used to match expressions in Mergers.
12 struct Pattern {
13   Kind kind;
14 
15   /// Expressions representing tensors simply have a tensor number.
16   unsigned tensorNum;
17 
18   /// Tensor operations point to their children.
19   std::shared_ptr<Pattern> e0;
20   std::shared_ptr<Pattern> e1;
21 
22   /// Constructors.
23   /// Rather than using these, please use the readable helper constructor
24   /// functions below to make tests more readable.
25   Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
26   Pattern(Kind kind, const std::shared_ptr<Pattern> &e0,
27           const std::shared_ptr<Pattern> &e1)
28       : kind(kind), e0(e0), e1(e1) {
29     assert(kind >= Kind::kMulF);
30     assert(e0 && e1);
31   }
32 };
33 
34 ///
35 /// Readable Pattern builder functions.
36 /// These should be preferred over the actual constructors.
37 ///
38 
39 static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
40   return std::make_shared<Pattern>(tensorNum);
41 }
42 
43 static std::shared_ptr<Pattern>
44 addfPattern(const std::shared_ptr<Pattern> &e0,
45             const std::shared_ptr<Pattern> &e1) {
46   return std::make_shared<Pattern>(Kind::kAddF, e0, e1);
47 }
48 
49 static std::shared_ptr<Pattern>
50 mulfPattern(const std::shared_ptr<Pattern> &e0,
51             const std::shared_ptr<Pattern> &e1) {
52   return std::make_shared<Pattern>(Kind::kMulF, e0, e1);
53 }
54 
55 class MergerTestBase : public ::testing::Test {
56 protected:
57   MergerTestBase(unsigned numTensors, unsigned numLoops)
58       : numTensors(numTensors), numLoops(numLoops),
59         merger(numTensors, numLoops) {}
60 
61   ///
62   /// Expression construction helpers.
63   ///
64 
65   unsigned tensor(unsigned tensor) {
66     return merger.addExp(Kind::kTensor, tensor);
67   }
68 
69   unsigned addf(unsigned e0, unsigned e1) {
70     return merger.addExp(Kind::kAddF, e0, e1);
71   }
72 
73   unsigned mulf(unsigned e0, unsigned e1) {
74     return merger.addExp(Kind::kMulF, e0, e1);
75   }
76 
77   ///
78   /// Comparison helpers.
79   ///
80 
81   /// For readability of tests.
82   unsigned lat(unsigned lat) { return lat; }
83 
84   /// Returns true if a lattice point with an expression matching the given
85   /// pattern and bits matching the given bits is present in lattice points
86   /// [p, p+n) of lattice set s. This is useful for testing partial ordering
87   /// constraints between lattice points. We generally know how contiguous
88   /// groups of lattice points should be ordered with respect to other groups,
89   /// but there is no required ordering within groups.
90   bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
91                            const std::shared_ptr<Pattern> &pattern,
92                            const BitVector &bits) {
93     for (unsigned i = p; i < p + n; ++i) {
94       if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
95           compareBits(s, i, bits))
96         return true;
97     }
98     return false;
99   }
100 
101   /// Wrapper over latPointWithinRange for readability of tests.
102   void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
103                                  const std::shared_ptr<Pattern> &pattern,
104                                  const BitVector &bits) {
105     EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits));
106   }
107 
108   /// Wrapper over expectLatPointWithinRange for a single lat point.
109   void expectLatPoint(unsigned s, unsigned p,
110                       const std::shared_ptr<Pattern> &pattern,
111                       const BitVector &bits) {
112     EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits));
113   }
114 
115   /// Converts a vector of (loop, tensor) pairs to a bitvector with the
116   /// corresponding bits set.
117   BitVector
118   loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) {
119     BitVector testBits = BitVector(numTensors + 1, false);
120     for (auto l : loops) {
121       auto loop = std::get<0>(l);
122       auto tensor = std::get<1>(l);
123       testBits.set(numTensors * loop + tensor);
124     }
125     return testBits;
126   }
127 
128   /// Returns true if the bits of lattice point p in set s match the given bits.
129   bool compareBits(unsigned s, unsigned p, const BitVector &bits) {
130     return merger.lat(merger.set(s)[p]).bits == bits;
131   }
132 
133   /// Check that there are n lattice points in set s.
134   void expectNumLatPoints(unsigned s, unsigned n) {
135     EXPECT_THAT(merger.set(s).size(), n);
136   }
137 
138   /// Compares expressions for equality. Equality is defined recursively as:
139   /// - Operations are equal if they have the same kind and children.
140   /// - Leaf tensors are equal if they refer to the same tensor.
141   bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
142     auto tensorExp = merger.exp(e);
143     if (tensorExp.kind != pattern->kind)
144       return false;
145     switch (tensorExp.kind) {
146     // Leaf.
147     case kTensor:
148       return tensorExp.tensor == pattern->tensorNum;
149     case kInvariant:
150     case kIndex:
151       llvm_unreachable("invariant not handled yet");
152     // Unary operations.
153     case kAbsF:
154     case kAbsC:
155     case kCeilF:
156     case kFloorF:
157     case kSqrtF:
158     case kSqrtC:
159     case kExpm1F:
160     case kExpm1C:
161     case kLog1pF:
162     case kLog1pC:
163     case kSinF:
164     case kSinC:
165     case kTanhF:
166     case kTanhC:
167     case kNegF:
168     case kNegC:
169     case kNegI:
170     case kTruncF:
171     case kExtF:
172     case kCastFS:
173     case kCastFU:
174     case kCastSF:
175     case kCastUF:
176     case kCastS:
177     case kCastU:
178     case kCastIdx:
179     case kTruncI:
180     case kCIm:
181     case kCRe:
182     case kBitCast:
183     case kBinaryBranch:
184     case kUnary:
185     case kShlI:
186     case kBinary:
187       return compareExpression(tensorExp.children.e0, pattern->e0);
188     // Binary operations.
189     case kMulF:
190     case kMulC:
191     case kMulI:
192     case kDivF:
193     case kDivC:
194     case kDivS:
195     case kDivU:
196     case kAddF:
197     case kAddC:
198     case kAddI:
199     case kSubF:
200     case kSubC:
201     case kSubI:
202     case kAndI:
203     case kOrI:
204     case kXorI:
205     case kShrS:
206     case kShrU:
207       return compareExpression(tensorExp.children.e0, pattern->e0) &&
208              compareExpression(tensorExp.children.e1, pattern->e1);
209     }
210     llvm_unreachable("unexpected kind");
211   }
212 
213   unsigned numTensors;
214   unsigned numLoops;
215   Merger merger;
216 };
217 
218 class MergerTest3T1L : public MergerTestBase {
219 protected:
220   // Our three tensors (two inputs, one output).
221   const unsigned t0 = 0, t1 = 1, t2 = 2;
222 
223   // Our single loop.
224   const unsigned l0 = 0;
225 
226   MergerTest3T1L() : MergerTestBase(3, 1) {
227     // Tensor 0: sparse input vector.
228     merger.addExp(Kind::kTensor, t0, -1u);
229     merger.setDim(t0, l0, Dim::kSparse);
230 
231     // Tensor 1: sparse input vector.
232     merger.addExp(Kind::kTensor, t1, -1u);
233     merger.setDim(t1, l0, Dim::kSparse);
234 
235     // Tensor 2: dense output vector.
236     merger.addExp(Kind::kTensor, t2, -1u);
237     merger.setDim(t2, l0, Dim::kDense);
238   }
239 };
240 
241 } // namespace
242 
243 /// Vector addition of 2 vectors, i.e.:
244 ///   a(i) = b(i) + c(i)
245 /// which should form the 3 lattice points
246 /// {
247 ///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
248 ///   lat( i_00 / tensor_0 )
249 ///   lat( i_01 / tensor_1 )
250 /// }
251 /// and after optimization, will reduce to the 2 lattice points
252 /// {
253 ///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
254 ///   lat( i_00 / tensor_0 )
255 /// }
256 TEST_F(MergerTest3T1L, VectorAdd2) {
257   // Construct expression.
258   auto e = addf(tensor(t0), tensor(t1));
259 
260   // Build lattices and check.
261   auto s = merger.buildLattices(e, l0);
262   expectNumLatPoints(s, 3);
263   expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
264                  loopsToBits({{l0, t0}, {l0, t1}}));
265   expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
266                             loopsToBits({{l0, t0}}));
267   expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
268                             loopsToBits({{l0, t1}}));
269 
270   // Optimize lattices and check.
271   s = merger.optimizeSet(s);
272   expectNumLatPoints(s, 3);
273   expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
274                  loopsToBits({{l0, t0}, {l0, t1}}));
275   expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
276                             loopsToBits({{l0, t0}}));
277   expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
278                             loopsToBits({{l0, t1}}));
279 }
280 
281 /// Vector multiplication of 2 vectors, i.e.:
282 ///   a(i) = b(i) * c(i)
283 /// which should form the single lattice point
284 /// {
285 ///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
286 /// }
287 TEST_F(MergerTest3T1L, VectorMul2) {
288   // Construct expression.
289   auto e = mulf(t0, t1);
290 
291   // Build lattices and check.
292   auto s = merger.buildLattices(e, l0);
293   expectNumLatPoints(s, 1);
294   expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
295                  loopsToBits({{l0, t0}, {l0, t1}}));
296 
297   // Optimize lattices and check.
298   s = merger.optimizeSet(s);
299   expectNumLatPoints(s, 1);
300   expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
301                  loopsToBits({{l0, t0}, {l0, t1}}));
302 }
303