1 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 2 #include "gmock/gmock.h" 3 #include "gtest/gtest.h" 4 #include <memory> 5 6 using namespace mlir::sparse_tensor; 7 8 namespace { 9 10 /// Simple recursive data structure used to match expressions in Mergers. 11 struct Pattern { 12 Kind kind; 13 14 /// Expressions representing tensors simply have a tensor number. 15 unsigned tensorNum; 16 17 /// Tensor operations point to their children. 18 std::shared_ptr<Pattern> e0; 19 std::shared_ptr<Pattern> e1; 20 21 /// Constructors. 22 /// Rather than using these, please use the readable helper constructor 23 /// functions below to make tests more readable. 24 Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {} 25 Pattern(Kind kind, const std::shared_ptr<Pattern> &e0, 26 const std::shared_ptr<Pattern> &e1) 27 : kind(kind), e0(e0), e1(e1) { 28 assert(kind >= Kind::kMulF); 29 assert(e0 && e1); 30 } 31 }; 32 33 /// 34 /// Readable Pattern builder functions. 35 /// These should be preferred over the actual constructors. 36 /// 37 38 static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) { 39 return std::make_shared<Pattern>(tensorNum); 40 } 41 42 static std::shared_ptr<Pattern> 43 addfPattern(const std::shared_ptr<Pattern> &e0, 44 const std::shared_ptr<Pattern> &e1) { 45 return std::make_shared<Pattern>(Kind::kAddF, e0, e1); 46 } 47 48 static std::shared_ptr<Pattern> 49 mulfPattern(const std::shared_ptr<Pattern> &e0, 50 const std::shared_ptr<Pattern> &e1) { 51 return std::make_shared<Pattern>(Kind::kMulF, e0, e1); 52 } 53 54 class MergerTestBase : public ::testing::Test { 55 protected: 56 MergerTestBase(unsigned numTensors, unsigned numLoops) 57 : numTensors(numTensors), numLoops(numLoops), 58 merger(numTensors, numLoops) {} 59 60 /// 61 /// Expression construction helpers. 62 /// 63 64 unsigned tensor(unsigned tensor) { 65 return merger.addExp(Kind::kTensor, tensor); 66 } 67 68 unsigned addf(unsigned e0, unsigned e1) { 69 return merger.addExp(Kind::kAddF, e0, e1); 70 } 71 72 unsigned mulf(unsigned e0, unsigned e1) { 73 return merger.addExp(Kind::kMulF, e0, e1); 74 } 75 76 /// 77 /// Comparison helpers. 78 /// 79 80 /// For readability of tests. 81 unsigned lat(unsigned lat) { return lat; } 82 83 /// Returns true if a lattice point with an expression matching the given 84 /// pattern and bits matching the given bits is present in lattice points 85 /// [p, p+n) of lattice set s. This is useful for testing partial ordering 86 /// constraints between lattice points. We generally know how contiguous 87 /// groups of lattice points should be ordered with respect to other groups, 88 /// but there is no required ordering within groups. 89 bool latPointWithinRange(unsigned s, unsigned p, unsigned n, 90 const std::shared_ptr<Pattern> &pattern, 91 const llvm::BitVector &bits) { 92 for (unsigned i = p; i < p + n; ++i) { 93 if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) && 94 compareBits(s, i, bits)) 95 return true; 96 } 97 return false; 98 } 99 100 /// Wrapper over latPointWithinRange for readability of tests. 101 void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n, 102 const std::shared_ptr<Pattern> &pattern, 103 const llvm::BitVector &bits) { 104 EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits)); 105 } 106 107 /// Wrapper over expectLatPointWithinRange for a single lat point. 108 void expectLatPoint(unsigned s, unsigned p, 109 const std::shared_ptr<Pattern> &pattern, 110 const llvm::BitVector &bits) { 111 EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits)); 112 } 113 114 /// Converts a vector of (loop, tensor) pairs to a bitvector with the 115 /// corresponding bits set. 116 llvm::BitVector 117 loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) { 118 llvm::BitVector testBits = llvm::BitVector(numTensors + 1, false); 119 for (auto l : loops) { 120 auto loop = std::get<0>(l); 121 auto tensor = std::get<1>(l); 122 testBits.set(numTensors * loop + tensor); 123 } 124 return testBits; 125 } 126 127 /// Returns true if the bits of lattice point p in set s match the given bits. 128 bool compareBits(unsigned s, unsigned p, const llvm::BitVector &bits) { 129 return merger.lat(merger.set(s)[p]).bits == bits; 130 } 131 132 /// Check that there are n lattice points in set s. 133 void expectNumLatPoints(unsigned s, unsigned n) { 134 EXPECT_THAT(merger.set(s).size(), n); 135 } 136 137 /// Compares expressions for equality. Equality is defined recursively as: 138 /// - Two expressions can only be equal if they have the same Kind. 139 /// - Two binary expressions are equal if they have the same Kind and their 140 /// children are equal. 141 /// - Expressions with Kind invariant or tensor are equal if they have the 142 /// same expression id. 143 bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) { 144 auto tensorExp = merger.exp(e); 145 if (tensorExp.kind != pattern->kind) 146 return false; 147 assert(tensorExp.kind != Kind::kInvariant && 148 "Invariant comparison not yet supported"); 149 switch (tensorExp.kind) { 150 case Kind::kTensor: 151 return tensorExp.tensor == pattern->tensorNum; 152 case Kind::kAbsF: 153 case Kind::kCeilF: 154 case Kind::kFloorF: 155 case Kind::kNegF: 156 case Kind::kNegI: 157 return compareExpression(tensorExp.children.e0, pattern->e0); 158 case Kind::kMulF: 159 case Kind::kMulI: 160 case Kind::kDivF: 161 case Kind::kDivS: 162 case Kind::kDivU: 163 case Kind::kAddF: 164 case Kind::kAddI: 165 case Kind::kSubF: 166 case Kind::kSubI: 167 case Kind::kAndI: 168 case Kind::kOrI: 169 case Kind::kXorI: 170 return compareExpression(tensorExp.children.e0, pattern->e0) && 171 compareExpression(tensorExp.children.e1, pattern->e1); 172 default: 173 llvm_unreachable("Unhandled Kind"); 174 } 175 } 176 177 unsigned numTensors; 178 unsigned numLoops; 179 Merger merger; 180 }; 181 182 class MergerTest3T1L : public MergerTestBase { 183 protected: 184 // Our three tensors (two inputs, one output). 185 const unsigned t0 = 0, t1 = 1, t2 = 2; 186 187 // Our single loop. 188 const unsigned l0 = 0; 189 190 MergerTest3T1L() : MergerTestBase(3, 1) { 191 // Tensor 0: sparse input vector. 192 merger.addExp(Kind::kTensor, t0, -1u); 193 merger.setDim(t0, l0, Dim::kSparse); 194 195 // Tensor 1: sparse input vector. 196 merger.addExp(Kind::kTensor, t1, -1u); 197 merger.setDim(t1, l0, Dim::kSparse); 198 199 // Tensor 2: dense output vector. 200 merger.addExp(Kind::kTensor, t2, -1u); 201 merger.setDim(t2, l0, Dim::kDense); 202 } 203 }; 204 205 } // namespace 206 207 /// Vector addition of 2 vectors, i.e.: 208 /// a(i) = b(i) + c(i) 209 /// which should form the 3 lattice points 210 /// { 211 /// lat( i_00 i_01 / (tensor_0 + tensor_1) ) 212 /// lat( i_00 / tensor_0 ) 213 /// lat( i_01 / tensor_1 ) 214 /// } 215 /// and after optimization, will reduce to the 2 lattice points 216 /// { 217 /// lat( i_00 i_01 / (tensor_0 + tensor_1) ) 218 /// lat( i_00 / tensor_0 ) 219 /// } 220 TEST_F(MergerTest3T1L, VectorAdd2) { 221 // Construct expression. 222 auto e = addf(tensor(t0), tensor(t1)); 223 224 // Build lattices and check. 225 auto s = merger.buildLattices(e, l0); 226 expectNumLatPoints(s, 3); 227 expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)), 228 loopsToBits({{l0, t0}, {l0, t1}})); 229 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), 230 loopsToBits({{l0, t0}})); 231 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), 232 loopsToBits({{l0, t1}})); 233 234 // Optimize lattices and check. 235 s = merger.optimizeSet(s); 236 expectNumLatPoints(s, 3); 237 expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)), 238 loopsToBits({{l0, t0}, {l0, t1}})); 239 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), 240 loopsToBits({{l0, t0}})); 241 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), 242 loopsToBits({{l0, t1}})); 243 } 244 245 /// Vector multiplication of 2 vectors, i.e.: 246 /// a(i) = b(i) * c(i) 247 /// which should form the single lattice point 248 /// { 249 /// lat( i_00 i_01 / (tensor_0 * tensor_1) ) 250 /// } 251 TEST_F(MergerTest3T1L, VectorMul2) { 252 // Construct expression. 253 auto e = mulf(t0, t1); 254 255 // Build lattices and check. 256 auto s = merger.buildLattices(e, l0); 257 expectNumLatPoints(s, 1); 258 expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)), 259 loopsToBits({{l0, t0}, {l0, t1}})); 260 261 // Optimize lattices and check. 262 s = merger.optimizeSet(s); 263 expectNumLatPoints(s, 1); 264 expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)), 265 loopsToBits({{l0, t0}, {l0, t1}})); 266 } 267