1 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 2 #include "gmock/gmock.h" 3 #include "gtest/gtest.h" 4 #include <memory> 5 6 using namespace mlir::sparse_tensor; 7 8 namespace { 9 10 /// Simple recursive data structure used to match expressions in Mergers. 11 struct Pattern { 12 Kind kind; 13 14 /// Expressions representing tensors simply have a tensor number. 15 unsigned tensorNum; 16 17 /// Tensor operations point to their children. 18 std::shared_ptr<Pattern> e0; 19 std::shared_ptr<Pattern> e1; 20 21 /// Constructors. 22 /// Rather than using these, please use the readable helper constructor 23 /// functions below to make tests more readable. 24 Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {} 25 Pattern(Kind kind, std::shared_ptr<Pattern> e0, std::shared_ptr<Pattern> e1) 26 : kind(kind), e0(e0), e1(e1) { 27 assert(kind >= Kind::kMulF); 28 assert(e0 && e1); 29 } 30 }; 31 32 /// 33 /// Readable Pattern builder functions. 34 /// These should be preferred over the actual constructors. 35 /// 36 37 static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) { 38 return std::make_shared<Pattern>(tensorNum); 39 } 40 41 static std::shared_ptr<Pattern> addfPattern(std::shared_ptr<Pattern> e0, 42 std::shared_ptr<Pattern> e1) { 43 return std::make_shared<Pattern>(Kind::kAddF, e0, e1); 44 } 45 46 static std::shared_ptr<Pattern> mulfPattern(std::shared_ptr<Pattern> e0, 47 std::shared_ptr<Pattern> e1) { 48 return std::make_shared<Pattern>(Kind::kMulF, e0, e1); 49 } 50 51 class MergerTestBase : public ::testing::Test { 52 protected: 53 MergerTestBase(unsigned numTensors, unsigned numLoops) 54 : numTensors(numTensors), numLoops(numLoops), 55 merger(numTensors, numLoops) {} 56 57 /// 58 /// Expression construction helpers. 59 /// 60 61 unsigned tensor(unsigned tensor) { 62 return merger.addExp(Kind::kTensor, tensor); 63 } 64 65 unsigned addf(unsigned e0, unsigned e1) { 66 return merger.addExp(Kind::kAddF, e0, e1); 67 } 68 69 unsigned mulf(unsigned e0, unsigned e1) { 70 return merger.addExp(Kind::kMulF, e0, e1); 71 } 72 73 /// 74 /// Comparison helpers. 75 /// 76 77 /// For readability of tests. 78 unsigned lat(unsigned lat) { return lat; } 79 80 /// Returns true if a lattice point with an expression matching the given 81 /// pattern and bits matching the given bits is present in lattice points 82 /// [p, p+n) of lattice set s. This is useful for testing partial ordering 83 /// constraints between lattice points. We generally know how contiguous 84 /// groups of lattice points should be ordered with respect to other groups, 85 /// but there is no required ordering within groups. 86 bool latPointWithinRange(unsigned s, unsigned p, unsigned n, 87 std::shared_ptr<Pattern> pattern, 88 llvm::BitVector bits) { 89 for (unsigned i = p; i < p + n; ++i) { 90 if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) && 91 compareBits(s, i, bits)) 92 return true; 93 } 94 return false; 95 } 96 97 /// Wrapper over latPointWithinRange for readability of tests. 98 void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n, 99 std::shared_ptr<Pattern> pattern, 100 llvm::BitVector bits) { 101 EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits)); 102 } 103 104 /// Wrapper over expectLatPointWithinRange for a single lat point. 105 void expectLatPoint(unsigned s, unsigned p, std::shared_ptr<Pattern> pattern, 106 llvm::BitVector bits) { 107 EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits)); 108 } 109 110 /// Converts a vector of (loop, tensor) pairs to a bitvector with the 111 /// corresponding bits set. 112 llvm::BitVector 113 loopsToBits(std::vector<std::pair<unsigned, unsigned>> loops) { 114 llvm::BitVector testBits = llvm::BitVector(numTensors + 1, false); 115 for (auto l : loops) { 116 auto loop = std::get<0>(l); 117 auto tensor = std::get<1>(l); 118 testBits.set(numTensors * loop + tensor); 119 } 120 return testBits; 121 } 122 123 /// Returns true if the bits of lattice point p in set s match the given bits. 124 bool compareBits(unsigned s, unsigned p, llvm::BitVector bits) { 125 return merger.lat(merger.set(s)[p]).bits == bits; 126 } 127 128 /// Check that there are n lattice points in set s. 129 void expectNumLatPoints(unsigned s, unsigned n) { 130 EXPECT_THAT(merger.set(s).size(), n); 131 } 132 133 /// Compares expressions for equality. Equality is defined recursively as: 134 /// - Two expressions can only be equal if they have the same Kind. 135 /// - Two binary expressions are equal if they have the same Kind and their 136 /// children are equal. 137 /// - Expressions with Kind invariant or tensor are equal if they have the 138 /// same expression id. 139 bool compareExpression(unsigned e, std::shared_ptr<Pattern> pattern) { 140 auto tensorExp = merger.exp(e); 141 if (tensorExp.kind != pattern->kind) 142 return false; 143 assert(tensorExp.kind != Kind::kInvariant && 144 "Invariant comparison not yet supported"); 145 switch (tensorExp.kind) { 146 case Kind::kTensor: 147 return tensorExp.tensor == pattern->tensorNum; 148 case Kind::kAbsF: 149 case Kind::kCeilF: 150 case Kind::kFloorF: 151 case Kind::kNegF: 152 case Kind::kNegI: 153 return compareExpression(tensorExp.children.e0, pattern->e0); 154 case Kind::kMulF: 155 case Kind::kMulI: 156 case Kind::kDivF: 157 case Kind::kDivS: 158 case Kind::kDivU: 159 case Kind::kAddF: 160 case Kind::kAddI: 161 case Kind::kSubF: 162 case Kind::kSubI: 163 case Kind::kAndI: 164 case Kind::kOrI: 165 case Kind::kXorI: 166 return compareExpression(tensorExp.children.e0, pattern->e0) && 167 compareExpression(tensorExp.children.e1, pattern->e1); 168 default: 169 llvm_unreachable("Unhandled Kind"); 170 } 171 } 172 173 unsigned numTensors; 174 unsigned numLoops; 175 Merger merger; 176 }; 177 178 class MergerTest3T1L : public MergerTestBase { 179 protected: 180 // Our three tensors (two inputs, one output). 181 const unsigned t0 = 0, t1 = 1, t2 = 2; 182 183 // Our single loop. 184 const unsigned l0 = 0; 185 186 MergerTest3T1L() : MergerTestBase(3, 1) { 187 // Tensor 0: sparse input vector. 188 merger.addExp(Kind::kTensor, t0, -1u); 189 merger.setDim(t0, l0, Dim::kSparse); 190 191 // Tensor 1: sparse input vector. 192 merger.addExp(Kind::kTensor, t1, -1u); 193 merger.setDim(t1, l0, Dim::kSparse); 194 195 // Tensor 2: dense output vector. 196 merger.addExp(Kind::kTensor, t2, -1u); 197 merger.setDim(t2, l0, Dim::kDense); 198 } 199 }; 200 201 } // anonymous namespace 202 203 /// Vector addition of 2 vectors, i.e.: 204 /// a(i) = b(i) + c(i) 205 /// which should form the 3 lattice points 206 /// { 207 /// lat( i_00 i_01 / (tensor_0 + tensor_1) ) 208 /// lat( i_00 / tensor_0 ) 209 /// lat( i_01 / tensor_1 ) 210 /// } 211 /// and after optimization, will reduce to the 2 lattice points 212 /// { 213 /// lat( i_00 i_01 / (tensor_0 + tensor_1) ) 214 /// lat( i_00 / tensor_0 ) 215 /// } 216 TEST_F(MergerTest3T1L, VectorAdd2) { 217 // Construct expression. 218 auto e = addf(tensor(t0), tensor(t1)); 219 220 // Build lattices and check. 221 auto s = merger.buildLattices(e, l0); 222 expectNumLatPoints(s, 3); 223 expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)), 224 loopsToBits({{l0, t0}, {l0, t1}})); 225 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), 226 loopsToBits({{l0, t0}})); 227 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), 228 loopsToBits({{l0, t1}})); 229 230 // Optimize lattices and check. 231 s = merger.optimizeSet(s); 232 expectNumLatPoints(s, 3); 233 expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)), 234 loopsToBits({{l0, t0}, {l0, t1}})); 235 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), 236 loopsToBits({{l0, t0}})); 237 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), 238 loopsToBits({{l0, t1}})); 239 } 240 241 /// Vector multiplication of 2 vectors, i.e.: 242 /// a(i) = b(i) * c(i) 243 /// which should form the single lattice point 244 /// { 245 /// lat( i_00 i_01 / (tensor_0 * tensor_1) ) 246 /// } 247 TEST_F(MergerTest3T1L, VectorMul2) { 248 // Construct expression. 249 auto e = mulf(t0, t1); 250 251 // Build lattices and check. 252 auto s = merger.buildLattices(e, l0); 253 expectNumLatPoints(s, 1); 254 expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)), 255 loopsToBits({{l0, t0}, {l0, t1}})); 256 257 // Optimize lattices and check. 258 s = merger.optimizeSet(s); 259 expectNumLatPoints(s, 1); 260 expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)), 261 loopsToBits({{l0, t0}, {l0, t1}})); 262 } 263