1 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 2 #include "gmock/gmock.h" 3 #include "gtest/gtest.h" 4 #include <memory> 5 6 using namespace mlir; 7 using namespace mlir::sparse_tensor; 8 9 namespace { 10 11 /// Simple recursive data structure used to match expressions in Mergers. 12 struct Pattern { 13 Kind kind; 14 15 /// Expressions representing tensors simply have a tensor number. 16 unsigned tensorNum; 17 18 /// Tensor operations point to their children. 19 std::shared_ptr<Pattern> e0; 20 std::shared_ptr<Pattern> e1; 21 22 /// Constructors. 23 /// Rather than using these, please use the readable helper constructor 24 /// functions below to make tests more readable. 25 Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {} 26 Pattern(Kind kind, const std::shared_ptr<Pattern> &e0, 27 const std::shared_ptr<Pattern> &e1) 28 : kind(kind), e0(e0), e1(e1) { 29 assert(kind >= Kind::kMulF); 30 assert(e0 && e1); 31 } 32 }; 33 34 /// 35 /// Readable Pattern builder functions. 36 /// These should be preferred over the actual constructors. 37 /// 38 39 static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) { 40 return std::make_shared<Pattern>(tensorNum); 41 } 42 43 static std::shared_ptr<Pattern> 44 addfPattern(const std::shared_ptr<Pattern> &e0, 45 const std::shared_ptr<Pattern> &e1) { 46 return std::make_shared<Pattern>(Kind::kAddF, e0, e1); 47 } 48 49 static std::shared_ptr<Pattern> 50 mulfPattern(const std::shared_ptr<Pattern> &e0, 51 const std::shared_ptr<Pattern> &e1) { 52 return std::make_shared<Pattern>(Kind::kMulF, e0, e1); 53 } 54 55 class MergerTestBase : public ::testing::Test { 56 protected: 57 MergerTestBase(unsigned numTensors, unsigned numLoops) 58 : numTensors(numTensors), numLoops(numLoops), 59 merger(numTensors, numLoops) {} 60 61 /// 62 /// Expression construction helpers. 63 /// 64 65 unsigned tensor(unsigned tensor) { 66 return merger.addExp(Kind::kTensor, tensor); 67 } 68 69 unsigned addf(unsigned e0, unsigned e1) { 70 return merger.addExp(Kind::kAddF, e0, e1); 71 } 72 73 unsigned mulf(unsigned e0, unsigned e1) { 74 return merger.addExp(Kind::kMulF, e0, e1); 75 } 76 77 /// 78 /// Comparison helpers. 79 /// 80 81 /// For readability of tests. 82 unsigned lat(unsigned lat) { return lat; } 83 84 /// Returns true if a lattice point with an expression matching the given 85 /// pattern and bits matching the given bits is present in lattice points 86 /// [p, p+n) of lattice set s. This is useful for testing partial ordering 87 /// constraints between lattice points. We generally know how contiguous 88 /// groups of lattice points should be ordered with respect to other groups, 89 /// but there is no required ordering within groups. 90 bool latPointWithinRange(unsigned s, unsigned p, unsigned n, 91 const std::shared_ptr<Pattern> &pattern, 92 const BitVector &bits) { 93 for (unsigned i = p; i < p + n; ++i) { 94 if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) && 95 compareBits(s, i, bits)) 96 return true; 97 } 98 return false; 99 } 100 101 /// Wrapper over latPointWithinRange for readability of tests. 102 void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n, 103 const std::shared_ptr<Pattern> &pattern, 104 const BitVector &bits) { 105 EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits)); 106 } 107 108 /// Wrapper over expectLatPointWithinRange for a single lat point. 109 void expectLatPoint(unsigned s, unsigned p, 110 const std::shared_ptr<Pattern> &pattern, 111 const BitVector &bits) { 112 EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits)); 113 } 114 115 /// Converts a vector of (loop, tensor) pairs to a bitvector with the 116 /// corresponding bits set. 117 BitVector 118 loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) { 119 BitVector testBits = BitVector(numTensors + 1, false); 120 for (auto l : loops) { 121 auto loop = std::get<0>(l); 122 auto tensor = std::get<1>(l); 123 testBits.set(numTensors * loop + tensor); 124 } 125 return testBits; 126 } 127 128 /// Returns true if the bits of lattice point p in set s match the given bits. 129 bool compareBits(unsigned s, unsigned p, const BitVector &bits) { 130 return merger.lat(merger.set(s)[p]).bits == bits; 131 } 132 133 /// Check that there are n lattice points in set s. 134 void expectNumLatPoints(unsigned s, unsigned n) { 135 EXPECT_THAT(merger.set(s).size(), n); 136 } 137 138 /// Compares expressions for equality. Equality is defined recursively as: 139 /// - Two expressions can only be equal if they have the same Kind. 140 /// - Two binary expressions are equal if they have the same Kind and their 141 /// children are equal. 142 /// - Expressions with Kind invariant or tensor are equal if they have the 143 /// same expression id. 144 bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) { 145 auto tensorExp = merger.exp(e); 146 if (tensorExp.kind != pattern->kind) 147 return false; 148 assert(tensorExp.kind != Kind::kInvariant && 149 "Invariant comparison not yet supported"); 150 switch (tensorExp.kind) { 151 case Kind::kTensor: 152 return tensorExp.tensor == pattern->tensorNum; 153 case Kind::kAbsF: 154 case Kind::kCeilF: 155 case Kind::kFloorF: 156 case Kind::kNegF: 157 case Kind::kNegI: 158 return compareExpression(tensorExp.children.e0, pattern->e0); 159 case Kind::kMulF: 160 case Kind::kMulI: 161 case Kind::kDivF: 162 case Kind::kDivS: 163 case Kind::kDivU: 164 case Kind::kAddF: 165 case Kind::kAddI: 166 case Kind::kSubF: 167 case Kind::kSubI: 168 case Kind::kAndI: 169 case Kind::kOrI: 170 case Kind::kXorI: 171 return compareExpression(tensorExp.children.e0, pattern->e0) && 172 compareExpression(tensorExp.children.e1, pattern->e1); 173 default: 174 llvm_unreachable("Unhandled Kind"); 175 } 176 } 177 178 unsigned numTensors; 179 unsigned numLoops; 180 Merger merger; 181 }; 182 183 class MergerTest3T1L : public MergerTestBase { 184 protected: 185 // Our three tensors (two inputs, one output). 186 const unsigned t0 = 0, t1 = 1, t2 = 2; 187 188 // Our single loop. 189 const unsigned l0 = 0; 190 191 MergerTest3T1L() : MergerTestBase(3, 1) { 192 // Tensor 0: sparse input vector. 193 merger.addExp(Kind::kTensor, t0, -1u); 194 merger.setDim(t0, l0, Dim::kSparse); 195 196 // Tensor 1: sparse input vector. 197 merger.addExp(Kind::kTensor, t1, -1u); 198 merger.setDim(t1, l0, Dim::kSparse); 199 200 // Tensor 2: dense output vector. 201 merger.addExp(Kind::kTensor, t2, -1u); 202 merger.setDim(t2, l0, Dim::kDense); 203 } 204 }; 205 206 } // namespace 207 208 /// Vector addition of 2 vectors, i.e.: 209 /// a(i) = b(i) + c(i) 210 /// which should form the 3 lattice points 211 /// { 212 /// lat( i_00 i_01 / (tensor_0 + tensor_1) ) 213 /// lat( i_00 / tensor_0 ) 214 /// lat( i_01 / tensor_1 ) 215 /// } 216 /// and after optimization, will reduce to the 2 lattice points 217 /// { 218 /// lat( i_00 i_01 / (tensor_0 + tensor_1) ) 219 /// lat( i_00 / tensor_0 ) 220 /// } 221 TEST_F(MergerTest3T1L, VectorAdd2) { 222 // Construct expression. 223 auto e = addf(tensor(t0), tensor(t1)); 224 225 // Build lattices and check. 226 auto s = merger.buildLattices(e, l0); 227 expectNumLatPoints(s, 3); 228 expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)), 229 loopsToBits({{l0, t0}, {l0, t1}})); 230 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), 231 loopsToBits({{l0, t0}})); 232 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), 233 loopsToBits({{l0, t1}})); 234 235 // Optimize lattices and check. 236 s = merger.optimizeSet(s); 237 expectNumLatPoints(s, 3); 238 expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)), 239 loopsToBits({{l0, t0}, {l0, t1}})); 240 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), 241 loopsToBits({{l0, t0}})); 242 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), 243 loopsToBits({{l0, t1}})); 244 } 245 246 /// Vector multiplication of 2 vectors, i.e.: 247 /// a(i) = b(i) * c(i) 248 /// which should form the single lattice point 249 /// { 250 /// lat( i_00 i_01 / (tensor_0 * tensor_1) ) 251 /// } 252 TEST_F(MergerTest3T1L, VectorMul2) { 253 // Construct expression. 254 auto e = mulf(t0, t1); 255 256 // Build lattices and check. 257 auto s = merger.buildLattices(e, l0); 258 expectNumLatPoints(s, 1); 259 expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)), 260 loopsToBits({{l0, t0}, {l0, t1}})); 261 262 // Optimize lattices and check. 263 s = merger.optimizeSet(s); 264 expectNumLatPoints(s, 1); 265 expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)), 266 loopsToBits({{l0, t0}, {l0, t1}})); 267 } 268