1 #include "mlir/Dialect/SparseTensor/Utils/Merger.h" 2 #include "gmock/gmock.h" 3 #include "gtest/gtest.h" 4 #include <memory> 5 6 using namespace mlir; 7 using namespace mlir::sparse_tensor; 8 9 namespace { 10 11 /// Simple recursive data structure used to match expressions in Mergers. 12 struct Pattern { 13 Kind kind; 14 15 /// Expressions representing tensors simply have a tensor number. 16 unsigned tensorNum; 17 18 /// Tensor operations point to their children. 19 std::shared_ptr<Pattern> e0; 20 std::shared_ptr<Pattern> e1; 21 22 /// Constructors. 23 /// Rather than using these, please use the readable helper constructor 24 /// functions below to make tests more readable. 25 Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {} 26 Pattern(Kind kind, const std::shared_ptr<Pattern> &e0, 27 const std::shared_ptr<Pattern> &e1) 28 : kind(kind), e0(e0), e1(e1) { 29 assert(kind >= Kind::kMulF); 30 assert(e0 && e1); 31 } 32 }; 33 34 /// 35 /// Readable Pattern builder functions. 36 /// These should be preferred over the actual constructors. 37 /// 38 39 static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) { 40 return std::make_shared<Pattern>(tensorNum); 41 } 42 43 static std::shared_ptr<Pattern> 44 addfPattern(const std::shared_ptr<Pattern> &e0, 45 const std::shared_ptr<Pattern> &e1) { 46 return std::make_shared<Pattern>(Kind::kAddF, e0, e1); 47 } 48 49 static std::shared_ptr<Pattern> 50 mulfPattern(const std::shared_ptr<Pattern> &e0, 51 const std::shared_ptr<Pattern> &e1) { 52 return std::make_shared<Pattern>(Kind::kMulF, e0, e1); 53 } 54 55 class MergerTestBase : public ::testing::Test { 56 protected: 57 MergerTestBase(unsigned numTensors, unsigned numLoops) 58 : numTensors(numTensors), numLoops(numLoops), 59 merger(numTensors, numLoops) {} 60 61 /// 62 /// Expression construction helpers. 63 /// 64 65 unsigned tensor(unsigned tensor) { 66 return merger.addExp(Kind::kTensor, tensor); 67 } 68 69 unsigned addf(unsigned e0, unsigned e1) { 70 return merger.addExp(Kind::kAddF, e0, e1); 71 } 72 73 unsigned mulf(unsigned e0, unsigned e1) { 74 return merger.addExp(Kind::kMulF, e0, e1); 75 } 76 77 /// 78 /// Comparison helpers. 79 /// 80 81 /// For readability of tests. 82 unsigned lat(unsigned lat) { return lat; } 83 84 /// Returns true if a lattice point with an expression matching the given 85 /// pattern and bits matching the given bits is present in lattice points 86 /// [p, p+n) of lattice set s. This is useful for testing partial ordering 87 /// constraints between lattice points. We generally know how contiguous 88 /// groups of lattice points should be ordered with respect to other groups, 89 /// but there is no required ordering within groups. 90 bool latPointWithinRange(unsigned s, unsigned p, unsigned n, 91 const std::shared_ptr<Pattern> &pattern, 92 const BitVector &bits) { 93 for (unsigned i = p; i < p + n; ++i) { 94 if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) && 95 compareBits(s, i, bits)) 96 return true; 97 } 98 return false; 99 } 100 101 /// Wrapper over latPointWithinRange for readability of tests. 102 void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n, 103 const std::shared_ptr<Pattern> &pattern, 104 const BitVector &bits) { 105 EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits)); 106 } 107 108 /// Wrapper over expectLatPointWithinRange for a single lat point. 109 void expectLatPoint(unsigned s, unsigned p, 110 const std::shared_ptr<Pattern> &pattern, 111 const BitVector &bits) { 112 EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits)); 113 } 114 115 /// Converts a vector of (loop, tensor) pairs to a bitvector with the 116 /// corresponding bits set. 117 BitVector 118 loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) { 119 BitVector testBits = BitVector(numTensors + 1, false); 120 for (auto l : loops) { 121 auto loop = std::get<0>(l); 122 auto tensor = std::get<1>(l); 123 testBits.set(numTensors * loop + tensor); 124 } 125 return testBits; 126 } 127 128 /// Returns true if the bits of lattice point p in set s match the given bits. 129 bool compareBits(unsigned s, unsigned p, const BitVector &bits) { 130 return merger.lat(merger.set(s)[p]).bits == bits; 131 } 132 133 /// Check that there are n lattice points in set s. 134 void expectNumLatPoints(unsigned s, unsigned n) { 135 EXPECT_THAT(merger.set(s).size(), n); 136 } 137 138 /// Compares expressions for equality. Equality is defined recursively as: 139 /// - Operations are equal if they have the same kind and children. 140 /// - Leaf tensors are equal if they refer to the same tensor. 141 bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) { 142 auto tensorExp = merger.exp(e); 143 if (tensorExp.kind != pattern->kind) 144 return false; 145 switch (tensorExp.kind) { 146 // Leaf. 147 case kTensor: 148 return tensorExp.tensor == pattern->tensorNum; 149 case kInvariant: 150 case kIndex: 151 llvm_unreachable("invariant not handled yet"); 152 // Unary operations. 153 case kAbsF: 154 case kAbsC: 155 case kCeilF: 156 case kFloorF: 157 case kSqrtF: 158 case kSqrtC: 159 case kExpm1F: 160 case kExpm1C: 161 case kLog1pF: 162 case kLog1pC: 163 case kSinF: 164 case kSinC: 165 case kTanhF: 166 case kTanhC: 167 case kNegF: 168 case kNegC: 169 case kNegI: 170 case kTruncF: 171 case kExtF: 172 case kCastFS: 173 case kCastFU: 174 case kCastSF: 175 case kCastUF: 176 case kCastS: 177 case kCastU: 178 case kCastIdx: 179 case kTruncI: 180 case kCIm: 181 case kCRe: 182 case kBitCast: 183 case kBinaryBranch: 184 case kUnary: 185 case kShlI: 186 case kBinary: 187 return compareExpression(tensorExp.children.e0, pattern->e0); 188 // Binary operations. 189 case kMulF: 190 case kMulC: 191 case kMulI: 192 case kDivF: 193 case kDivC: 194 case kDivS: 195 case kDivU: 196 case kAddF: 197 case kAddC: 198 case kAddI: 199 case kSubF: 200 case kSubC: 201 case kSubI: 202 case kAndI: 203 case kOrI: 204 case kXorI: 205 case kShrS: 206 case kShrU: 207 return compareExpression(tensorExp.children.e0, pattern->e0) && 208 compareExpression(tensorExp.children.e1, pattern->e1); 209 } 210 llvm_unreachable("unexpected kind"); 211 } 212 213 unsigned numTensors; 214 unsigned numLoops; 215 Merger merger; 216 }; 217 218 class MergerTest3T1L : public MergerTestBase { 219 protected: 220 // Our three tensors (two inputs, one output). 221 const unsigned t0 = 0, t1 = 1, t2 = 2; 222 223 // Our single loop. 224 const unsigned l0 = 0; 225 226 MergerTest3T1L() : MergerTestBase(3, 1) { 227 // Tensor 0: sparse input vector. 228 merger.addExp(Kind::kTensor, t0, -1u); 229 merger.setDim(t0, l0, Dim::kSparse); 230 231 // Tensor 1: sparse input vector. 232 merger.addExp(Kind::kTensor, t1, -1u); 233 merger.setDim(t1, l0, Dim::kSparse); 234 235 // Tensor 2: dense output vector. 236 merger.addExp(Kind::kTensor, t2, -1u); 237 merger.setDim(t2, l0, Dim::kDense); 238 } 239 }; 240 241 } // namespace 242 243 /// Vector addition of 2 vectors, i.e.: 244 /// a(i) = b(i) + c(i) 245 /// which should form the 3 lattice points 246 /// { 247 /// lat( i_00 i_01 / (tensor_0 + tensor_1) ) 248 /// lat( i_00 / tensor_0 ) 249 /// lat( i_01 / tensor_1 ) 250 /// } 251 /// and after optimization, will reduce to the 2 lattice points 252 /// { 253 /// lat( i_00 i_01 / (tensor_0 + tensor_1) ) 254 /// lat( i_00 / tensor_0 ) 255 /// } 256 TEST_F(MergerTest3T1L, VectorAdd2) { 257 // Construct expression. 258 auto e = addf(tensor(t0), tensor(t1)); 259 260 // Build lattices and check. 261 auto s = merger.buildLattices(e, l0); 262 expectNumLatPoints(s, 3); 263 expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)), 264 loopsToBits({{l0, t0}, {l0, t1}})); 265 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), 266 loopsToBits({{l0, t0}})); 267 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), 268 loopsToBits({{l0, t1}})); 269 270 // Optimize lattices and check. 271 s = merger.optimizeSet(s); 272 expectNumLatPoints(s, 3); 273 expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)), 274 loopsToBits({{l0, t0}, {l0, t1}})); 275 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0), 276 loopsToBits({{l0, t0}})); 277 expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1), 278 loopsToBits({{l0, t1}})); 279 } 280 281 /// Vector multiplication of 2 vectors, i.e.: 282 /// a(i) = b(i) * c(i) 283 /// which should form the single lattice point 284 /// { 285 /// lat( i_00 i_01 / (tensor_0 * tensor_1) ) 286 /// } 287 TEST_F(MergerTest3T1L, VectorMul2) { 288 // Construct expression. 289 auto e = mulf(t0, t1); 290 291 // Build lattices and check. 292 auto s = merger.buildLattices(e, l0); 293 expectNumLatPoints(s, 1); 294 expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)), 295 loopsToBits({{l0, t0}, {l0, t1}})); 296 297 // Optimize lattices and check. 298 s = merger.optimizeSet(s); 299 expectNumLatPoints(s, 1); 300 expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)), 301 loopsToBits({{l0, t0}, {l0, t1}})); 302 } 303