1 #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
2 #include "gmock/gmock.h"
3 #include "gtest/gtest.h"
4 #include <memory>
5 
6 using namespace mlir::sparse_tensor;
7 
8 namespace {
9 
10 /// Simple recursive data structure used to match expressions in Mergers.
11 struct Pattern {
12   Kind kind;
13 
14   /// Expressions representing tensors simply have a tensor number.
15   unsigned tensorNum;
16 
17   /// Tensor operations point to their children.
18   std::shared_ptr<Pattern> e0;
19   std::shared_ptr<Pattern> e1;
20 
21   /// Constructors.
22   /// Rather than using these, please use the readable helper constructor
23   /// functions below to make tests more readable.
24   Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
25   Pattern(Kind kind, std::shared_ptr<Pattern> e0, std::shared_ptr<Pattern> e1)
26       : kind(kind), e0(e0), e1(e1) {
27     assert(kind >= Kind::kMulF);
28     assert(e0 && e1);
29   }
30 };
31 
32 ///
33 /// Readable Pattern builder functions.
34 /// These should be preferred over the actual constructors.
35 ///
36 
37 static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
38   return std::make_shared<Pattern>(tensorNum);
39 }
40 
41 static std::shared_ptr<Pattern> addfPattern(std::shared_ptr<Pattern> e0,
42                                             std::shared_ptr<Pattern> e1) {
43   return std::make_shared<Pattern>(Kind::kAddF, e0, e1);
44 }
45 
46 static std::shared_ptr<Pattern> mulfPattern(std::shared_ptr<Pattern> e0,
47                                             std::shared_ptr<Pattern> e1) {
48   return std::make_shared<Pattern>(Kind::kMulF, e0, e1);
49 }
50 
51 class MergerTestBase : public ::testing::Test {
52 protected:
53   MergerTestBase(unsigned numTensors, unsigned numLoops)
54       : numTensors(numTensors), numLoops(numLoops),
55         merger(numTensors, numLoops) {}
56 
57   ///
58   /// Expression construction helpers.
59   ///
60 
61   unsigned tensor(unsigned tensor) {
62     return merger.addExp(Kind::kTensor, tensor);
63   }
64 
65   unsigned addf(unsigned e0, unsigned e1) {
66     return merger.addExp(Kind::kAddF, e0, e1);
67   }
68 
69   unsigned mulf(unsigned e0, unsigned e1) {
70     return merger.addExp(Kind::kMulF, e0, e1);
71   }
72 
73   ///
74   /// Comparison helpers.
75   ///
76 
77   /// For readability of tests.
78   unsigned lat(unsigned lat) { return lat; }
79 
80   /// Returns true if a lattice point with an expression matching the given
81   /// pattern and bits matching the given bits is present in lattice points
82   /// [p, p+n) of lattice set s. This is useful for testing partial ordering
83   /// constraints between lattice points. We generally know how contiguous
84   /// groups of lattice points should be ordered with respect to other groups,
85   /// but there is no required ordering within groups.
86   bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
87                            std::shared_ptr<Pattern> pattern,
88                            llvm::BitVector bits) {
89     for (unsigned i = p; i < p + n; ++i) {
90       if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
91           compareBits(s, i, bits))
92         return true;
93     }
94     return false;
95   }
96 
97   /// Wrapper over latPointWithinRange for readability of tests.
98   void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
99                                  std::shared_ptr<Pattern> pattern,
100                                  llvm::BitVector bits) {
101     EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits));
102   }
103 
104   /// Wrapper over expectLatPointWithinRange for a single lat point.
105   void expectLatPoint(unsigned s, unsigned p, std::shared_ptr<Pattern> pattern,
106                       llvm::BitVector bits) {
107     EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits));
108   }
109 
110   /// Converts a vector of (loop, tensor) pairs to a bitvector with the
111   /// corresponding bits set.
112   llvm::BitVector
113   loopsToBits(std::vector<std::pair<unsigned, unsigned>> loops) {
114     llvm::BitVector testBits = llvm::BitVector(numTensors + 1, false);
115     for (auto l : loops) {
116       auto loop = std::get<0>(l);
117       auto tensor = std::get<1>(l);
118       testBits.set(numTensors * loop + tensor);
119     }
120     return testBits;
121   }
122 
123   /// Returns true if the bits of lattice point p in set s match the given bits.
124   bool compareBits(unsigned s, unsigned p, llvm::BitVector bits) {
125     return merger.lat(merger.set(s)[p]).bits == bits;
126   }
127 
128   /// Check that there are n lattice points in set s.
129   void expectNumLatPoints(unsigned s, unsigned n) {
130     EXPECT_THAT(merger.set(s).size(), n);
131   }
132 
133   /// Compares expressions for equality. Equality is defined recursively as:
134   /// - Two expressions can only be equal if they have the same Kind.
135   /// - Two binary expressions are equal if they have the same Kind and their
136   ///     children are equal.
137   /// - Expressions with Kind invariant or tensor are equal if they have the
138   ///     same expression id.
139   bool compareExpression(unsigned e, std::shared_ptr<Pattern> pattern) {
140     auto tensorExp = merger.exp(e);
141     if (tensorExp.kind != pattern->kind)
142       return false;
143     assert(tensorExp.kind != Kind::kInvariant &&
144            "Invariant comparison not yet supported");
145     switch (tensorExp.kind) {
146     case Kind::kTensor:
147       return tensorExp.tensor == pattern->tensorNum;
148     case Kind::kAbsF:
149     case Kind::kCeilF:
150     case Kind::kFloorF:
151     case Kind::kNegF:
152     case Kind::kNegI:
153       return compareExpression(tensorExp.children.e0, pattern->e0);
154     case Kind::kMulF:
155     case Kind::kMulI:
156     case Kind::kDivF:
157     case Kind::kDivS:
158     case Kind::kDivU:
159     case Kind::kAddF:
160     case Kind::kAddI:
161     case Kind::kSubF:
162     case Kind::kSubI:
163     case Kind::kAndI:
164     case Kind::kOrI:
165     case Kind::kXorI:
166       return compareExpression(tensorExp.children.e0, pattern->e0) &&
167              compareExpression(tensorExp.children.e1, pattern->e1);
168     default:
169       llvm_unreachable("Unhandled Kind");
170     }
171   }
172 
173   unsigned numTensors;
174   unsigned numLoops;
175   Merger merger;
176 };
177 
178 class MergerTest3T1L : public MergerTestBase {
179 protected:
180   // Our three tensors (two inputs, one output).
181   const unsigned t0 = 0, t1 = 1, t2 = 2;
182 
183   // Our single loop.
184   const unsigned l0 = 0;
185 
186   MergerTest3T1L() : MergerTestBase(3, 1) {
187     // Tensor 0: sparse input vector.
188     merger.addExp(Kind::kTensor, t0, -1u);
189     merger.setDim(t0, l0, Dim::kSparse);
190 
191     // Tensor 1: sparse input vector.
192     merger.addExp(Kind::kTensor, t1, -1u);
193     merger.setDim(t1, l0, Dim::kSparse);
194 
195     // Tensor 2: dense output vector.
196     merger.addExp(Kind::kTensor, t2, -1u);
197     merger.setDim(t2, l0, Dim::kDense);
198   }
199 };
200 
201 } // anonymous namespace
202 
203 /// Vector addition of 2 vectors, i.e.:
204 ///   a(i) = b(i) + c(i)
205 /// which should form the 3 lattice points
206 /// {
207 ///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
208 ///   lat( i_00 / tensor_0 )
209 ///   lat( i_01 / tensor_1 )
210 /// }
211 /// and after optimization, will reduce to the 2 lattice points
212 /// {
213 ///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
214 ///   lat( i_00 / tensor_0 )
215 /// }
216 TEST_F(MergerTest3T1L, VectorAdd2) {
217   // Construct expression.
218   auto e = addf(tensor(t0), tensor(t1));
219 
220   // Build lattices and check.
221   auto s = merger.buildLattices(e, l0);
222   expectNumLatPoints(s, 3);
223   expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
224                  loopsToBits({{l0, t0}, {l0, t1}}));
225   expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
226                             loopsToBits({{l0, t0}}));
227   expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
228                             loopsToBits({{l0, t1}}));
229 
230   // Optimize lattices and check.
231   s = merger.optimizeSet(s);
232   expectNumLatPoints(s, 3);
233   expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
234                  loopsToBits({{l0, t0}, {l0, t1}}));
235   expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
236                             loopsToBits({{l0, t0}}));
237   expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
238                             loopsToBits({{l0, t1}}));
239 }
240 
241 /// Vector multiplication of 2 vectors, i.e.:
242 ///   a(i) = b(i) * c(i)
243 /// which should form the single lattice point
244 /// {
245 ///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
246 /// }
247 TEST_F(MergerTest3T1L, VectorMul2) {
248   // Construct expression.
249   auto e = mulf(t0, t1);
250 
251   // Build lattices and check.
252   auto s = merger.buildLattices(e, l0);
253   expectNumLatPoints(s, 1);
254   expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
255                  loopsToBits({{l0, t0}, {l0, t1}}));
256 
257   // Optimize lattices and check.
258   s = merger.optimizeSet(s);
259   expectNumLatPoints(s, 1);
260   expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
261                  loopsToBits({{l0, t0}, {l0, t1}}));
262 }
263