#include "mlir/Dialect/SparseTensor/Utils/Merger.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

using namespace mlir::sparse_tensor;

namespace {

/// Simple recursive data structure used to match expressions in Mergers.
struct Pattern {
  Kind kind;

  /// Expressions representing tensors simply have a tensor number.
  /// Defaults to 0 so that operation patterns never carry an indeterminate
  /// value.
  unsigned tensorNum = 0;

  /// Tensor operations point to their children.
  std::shared_ptr<Pattern> e0;
  std::shared_ptr<Pattern> e1;

  /// Constructors.
  /// Rather than using these, please use the readable helper constructor
  /// functions below to make tests more readable.

  /// Builds a leaf pattern of kind `kTensor` matching tensor `tensorNum`.
  explicit Pattern(unsigned tensorNum)
      : kind(Kind::kTensor), tensorNum(tensorNum) {}

  /// Builds an operation pattern of the given kind with two children.
  /// Asserts that `kind` is not ordered before `Kind::kMulF` and that both
  /// children are non-null.
  Pattern(Kind kind, std::shared_ptr<Pattern> e0, std::shared_ptr<Pattern> e1)
      : kind(kind), e0(e0), e1(e1) {
    assert(kind >= Kind::kMulF);
    assert(e0 && e1);
  }
};

///
/// Readable Pattern builder functions.
/// These should be preferred over the actual constructors.
///

/// Returns a pattern matching the tensor with number `tensorNum`.
static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
  return std::make_shared<Pattern>(tensorNum);
}

/// Returns a pattern matching a `kAddF` expression over the two sub-patterns.
static std::shared_ptr<Pattern> addfPattern(std::shared_ptr<Pattern> e0,
                                            std::shared_ptr<Pattern> e1) {
  return std::make_shared<Pattern>(Kind::kAddF, e0, e1);
}

/// Returns a pattern matching a `kMulF` expression over the two sub-patterns.
static std::shared_ptr<Pattern> mulfPattern(std::shared_ptr<Pattern> e0,
                                            std::shared_ptr<Pattern> e1) {
  return std::make_shared<Pattern>(Kind::kMulF, e0, e1);
}

/// Common fixture: owns a Merger built for a given number of tensors and
/// loops, and provides helpers to build expressions and to check the lattice
/// sets the Merger produces.
class MergerTestBase : public ::testing::Test {
protected:
  MergerTestBase(unsigned numTensors, unsigned numLoops)
      : numTensors(numTensors), numLoops(numLoops),
        merger(numTensors, numLoops) {}

  ///
  /// Expression construction helpers.
  /// Each helper adds an expression to the merger and returns the value
  /// produced by `Merger::addExp`, which the other helpers accept as an
  /// operand.
  ///

  unsigned tensor(unsigned tensor) {
    return merger.addExp(Kind::kTensor, tensor);
  }

  unsigned addf(unsigned e0, unsigned e1) {
    return merger.addExp(Kind::kAddF, e0, e1);
  }

  unsigned mulf(unsigned e0, unsigned e1) {
    return merger.addExp(Kind::kMulF, e0, e1);
  }

  ///
  /// Comparison helpers.
  ///

  /// For readability of tests.
  unsigned lat(unsigned lat) { return lat; }

  /// Returns true if a lattice point with an expression matching the given
  /// pattern and bits matching the given bits is present in lattice points
  /// [p, p+n) of lattice set s. This is useful for testing partial ordering
  /// constraints between lattice points. We generally know how contiguous
  /// groups of lattice points should be ordered with respect to other groups,
  /// but there is no required ordering within groups.
  bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
                           const std::shared_ptr<Pattern> &pattern,
                           const llvm::BitVector &bits) {
    for (unsigned i = p; i < p + n; ++i) {
      if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
          compareBits(s, i, bits))
        return true;
    }
    return false;
  }

  /// Wrapper over latPointWithinRange for readability of tests.
  void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
                                 const std::shared_ptr<Pattern> &pattern,
                                 const llvm::BitVector &bits) {
    EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits));
  }

  /// Wrapper over latPointWithinRange for a single lat point, i.e. the range
  /// [p, p+1).
  void expectLatPoint(unsigned s, unsigned p,
                      const std::shared_ptr<Pattern> &pattern,
                      const llvm::BitVector &bits) {
    EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits));
  }

  /// Converts a vector of (loop, tensor) pairs to a bitvector with the
  /// corresponding bits set.
  ///
  /// NOTE(review): the vector holds `numTensors + 1` bits while the bit index
  /// is `numTensors * loop + tensor`, so only loop 0 is addressed safely.
  /// That is sufficient for the single-loop fixture in this file; confirm the
  /// layout against the Merger before adding fixtures with several loops.
  llvm::BitVector
  loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) {
    llvm::BitVector testBits = llvm::BitVector(numTensors + 1, false);
    for (const auto &l : loops) {
      auto loop = std::get<0>(l);
      auto tensor = std::get<1>(l);
      testBits.set(numTensors * loop + tensor);
    }
    return testBits;
  }

  /// Returns true if the bits of lattice point p in set s match the given bits.
  bool compareBits(unsigned s, unsigned p, const llvm::BitVector &bits) {
    return merger.lat(merger.set(s)[p]).bits == bits;
  }

  /// Check that there are n lattice points in set s.
  void expectNumLatPoints(unsigned s, unsigned n) {
    EXPECT_THAT(merger.set(s).size(), n);
  }

  /// Compares expressions for equality. Equality is defined recursively as:
  /// - Two expressions can only be equal if they have the same Kind.
  /// - Two unary expressions are equal if they have the same Kind and their
  ///   single child (e0) is equal.
  /// - Two binary expressions are equal if they have the same Kind and their
  ///   children are equal.
  /// - Tensor expressions are equal if they have the same tensor number.
  /// - Comparing invariant expressions is not yet supported and asserts.
  bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
    const auto &tensorExp = merger.exp(e);
    if (tensorExp.kind != pattern->kind)
      return false;
    assert(tensorExp.kind != Kind::kInvariant &&
           "Invariant comparison not yet supported");
    switch (tensorExp.kind) {
    case Kind::kTensor:
      return tensorExp.tensor == pattern->tensorNum;
    case Kind::kAbsF:
    case Kind::kCeilF:
    case Kind::kFloorF:
    case Kind::kNegF:
    case Kind::kNegI:
      return compareExpression(tensorExp.children.e0, pattern->e0);
    case Kind::kMulF:
    case Kind::kMulI:
    case Kind::kDivF:
    case Kind::kDivS:
    case Kind::kDivU:
    case Kind::kAddF:
    case Kind::kAddI:
    case Kind::kSubF:
    case Kind::kSubI:
    case Kind::kAndI:
    case Kind::kOrI:
    case Kind::kXorI:
      return compareExpression(tensorExp.children.e0, pattern->e0) &&
             compareExpression(tensorExp.children.e1, pattern->e1);
    default:
      llvm_unreachable("Unhandled Kind");
    }
  }

  unsigned numTensors;
  unsigned numLoops;
  Merger merger;
};

/// Fixture with three tensors (two inputs, one output) and a single loop.
class MergerTest3T1L : public MergerTestBase {
protected:
  // Our three tensors (two inputs, one output).
  const unsigned t0 = 0, t1 = 1, t2 = 2;

  // Our single loop.
  const unsigned l0 = 0;

  MergerTest3T1L() : MergerTestBase(3, 1) {
    // Tensor 0: sparse input vector.
    merger.addExp(Kind::kTensor, t0, -1u);
    merger.setDim(t0, l0, Dim::kSparse);

    // Tensor 1: sparse input vector.
    merger.addExp(Kind::kTensor, t1, -1u);
    merger.setDim(t1, l0, Dim::kSparse);

    // Tensor 2: dense output vector.
    merger.addExp(Kind::kTensor, t2, -1u);
    merger.setDim(t2, l0, Dim::kDense);
  }
};

} // anonymous namespace

/// Vector addition of 2 vectors, i.e.:
///   a(i) = b(i) + c(i)
/// which should form the 3 lattice points
/// {
///   lat( i_00 i_01 / (tensor_0 + tensor_1) )
///   lat( i_00 / tensor_0 )
///   lat( i_01 / tensor_1 )
/// }
/// and after optimization, the test expects the set to still consist of the
/// same 3 lattice points.
TEST_F(MergerTest3T1L, VectorAdd2) {
  // Construct expression.
  auto e = addf(tensor(t0), tensor(t1));

  // Build lattices and check.
  auto s = merger.buildLattices(e, l0);
  expectNumLatPoints(s, 3);
  expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
                 loopsToBits({{l0, t0}, {l0, t1}}));
  // The two single-tensor points may appear in either order.
  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
                            loopsToBits({{l0, t0}}));
  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
                            loopsToBits({{l0, t1}}));

  // Optimize lattices and check.
  s = merger.optimizeSet(s);
  expectNumLatPoints(s, 3);
  expectLatPoint(s, lat(0), addfPattern(tensorPattern(t0), tensorPattern(t1)),
                 loopsToBits({{l0, t0}, {l0, t1}}));
  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t0),
                            loopsToBits({{l0, t0}}));
  expectLatPointWithinRange(s, lat(1), 2, tensorPattern(t1),
                            loopsToBits({{l0, t1}}));
}

/// Vector multiplication of 2 vectors, i.e.:
///   a(i) = b(i) * c(i)
/// which should form the single lattice point
/// {
///   lat( i_00 i_01 / (tensor_0 * tensor_1) )
/// }
TEST_F(MergerTest3T1L, VectorMul2) {
  // Construct expression. Operands are built with tensor() so that mulf()
  // receives expression ids rather than raw tensor numbers, as in VectorAdd2.
  auto e = mulf(tensor(t0), tensor(t1));

  // Build lattices and check.
  auto s = merger.buildLattices(e, l0);
  expectNumLatPoints(s, 1);
  expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
                 loopsToBits({{l0, t0}, {l0, t1}}));

  // Optimize lattices and check.
  s = merger.optimizeSet(s);
  expectNumLatPoints(s, 1);
  expectLatPoint(s, lat(0), mulfPattern(tensorPattern(t0), tensorPattern(t1)),
                 loopsToBits({{l0, t0}, {l0, t1}}));
}