140843347SGus Smith #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
2*66ae1d60SPeiming Liu #include "llvm/Support/Compiler.h"
340843347SGus Smith #include "gmock/gmock.h"
440843347SGus Smith #include "gtest/gtest.h"
540843347SGus Smith #include <memory>
640843347SGus Smith
76842ec42SRiver Riddle using namespace mlir;
840843347SGus Smith using namespace mlir::sparse_tensor;
940843347SGus Smith
// MSVC expands `, ##__VA_ARGS__` differently from GCC/Clang and reports
// 'warning C4002: too many arguments for function-like macro invocation' for
// the X-macros in this file. Silence that warning here; the saved warning
// state is popped again at the end of the file.
#if !defined(__clang__) && defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4002)
#endif
18*66ae1d60SPeiming Liu
1940843347SGus Smith namespace {
2040843347SGus Smith
///
/// X-macros that iterate over the binary operations under test and over
/// combinations (pairs) of them.
///

/// Invokes DO(op, kind) once for every binary operation that gets an
/// `<op>Pattern` builder and an `<op>Expr` helper.
#define FOREVERY_BINOP(DO)                                                     \
  DO(mulf, Kind::kMulF)                                                        \
  DO(mulc, Kind::kMulC)                                                        \
  DO(muli, Kind::kMulI)                                                        \
  DO(addf, Kind::kAddF)                                                        \
  DO(addc, Kind::kAddC)                                                        \
  DO(addi, Kind::kAddI)                                                        \
  DO(subf, Kind::kSubF)                                                        \
  DO(subc, Kind::kSubC)                                                        \
  DO(subi, Kind::kSubI)                                                        \
  DO(andi, Kind::kAndI)                                                        \
  DO(xori, Kind::kXorI)                                                        \
  DO(ori, Kind::kOrI)

/// Invokes TEST(op, extra...) for every disjunctive binary operation that
/// needs no special handling. The `, ##__VA_ARGS__` form (a GNU extension
/// also accepted by Clang and MSVC) drops the comma when no extra arguments
/// are given, so both TEST(op) and TEST(op, other) expansions work.
// TODO: Disjunctive binary operations that need special handling are not
// included yet; e.g., subtraction yields a different pattern, as the lattice
// point for the right-hand operand alone is mapped to a negation.
#define FOREVERY_COMMON_DISJ_BINOP(TEST, ...)                                  \
  TEST(addf, ##__VA_ARGS__)                                                    \
  TEST(addc, ##__VA_ARGS__)                                                    \
  TEST(addi, ##__VA_ARGS__)                                                    \
  TEST(xori, ##__VA_ARGS__)                                                    \
  TEST(ori, ##__VA_ARGS__)

/// Invokes TEST(op, extra...) for every conjunctive binary operation that
/// needs no special handling.
// TODO: Conjunctive binary operations that need special handling are not
// included yet; e.g., division is not tested (for now) as it needs a
// constant non-zero divisor.
#define FOREVERY_COMMON_CONJ_BINOP(TEST, ...)                                  \
  TEST(mulf, ##__VA_ARGS__)                                                    \
  TEST(mulc, ##__VA_ARGS__)                                                    \
  TEST(muli, ##__VA_ARGS__)                                                    \
  TEST(andi, ##__VA_ARGS__)

/// Invokes TEST(conj, disj) for every (common conjunctive, common disjunctive)
/// pair of operations.
#define FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(TEST)                          \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addf)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addc)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, addi)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, xori)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, ori)

/// Invokes TEST(conj1, conj2) for every pair of common conjunctive operations.
#define FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(TEST)                          \
  FOREVERY_COMMON_CONJ_BINOP(TEST, mulf)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, mulc)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, muli)                                       \
  FOREVERY_COMMON_CONJ_BINOP(TEST, andi)

/// Invokes TEST(disj1, disj2) for every pair of common disjunctive operations
/// (iterated in the same order as FOREVERY_COMMON_DISJ_BINOP).
#define FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(TEST)                          \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addf)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addc)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, addi)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, xori)                                       \
  FOREVERY_COMMON_DISJ_BINOP(TEST, ori)
78*66ae1d60SPeiming Liu
79*66ae1d60SPeiming Liu ///
80*66ae1d60SPeiming Liu /// Helper classes/functions for testing Merger.
81*66ae1d60SPeiming Liu ///
82*66ae1d60SPeiming Liu
8340843347SGus Smith /// Simple recursive data structure used to match expressions in Mergers.
8440843347SGus Smith struct Pattern {
8540843347SGus Smith Kind kind;
8640843347SGus Smith
8740843347SGus Smith /// Expressions representing tensors simply have a tensor number.
8840843347SGus Smith unsigned tensorNum;
8940843347SGus Smith
9040843347SGus Smith /// Tensor operations point to their children.
9140843347SGus Smith std::shared_ptr<Pattern> e0;
9240843347SGus Smith std::shared_ptr<Pattern> e1;
9340843347SGus Smith
9440843347SGus Smith /// Constructors.
9540843347SGus Smith /// Rather than using these, please use the readable helper constructor
9640843347SGus Smith /// functions below to make tests more readable.
Pattern__anonb1ca32140111::Pattern9740843347SGus Smith Pattern(unsigned tensorNum) : kind(Kind::kTensor), tensorNum(tensorNum) {}
Pattern__anonb1ca32140111::Pattern981fc096afSMehdi Amini Pattern(Kind kind, const std::shared_ptr<Pattern> &e0,
991fc096afSMehdi Amini const std::shared_ptr<Pattern> &e1)
10040843347SGus Smith : kind(kind), e0(e0), e1(e1) {
10140843347SGus Smith assert(kind >= Kind::kMulF);
10240843347SGus Smith assert(e0 && e1);
10340843347SGus Smith }
10440843347SGus Smith };
10540843347SGus Smith
10640843347SGus Smith ///
10740843347SGus Smith /// Readable Pattern builder functions.
10840843347SGus Smith /// These should be preferred over the actual constructors.
10940843347SGus Smith ///
11040843347SGus Smith
tensorPattern(unsigned tensorNum)11140843347SGus Smith static std::shared_ptr<Pattern> tensorPattern(unsigned tensorNum) {
11240843347SGus Smith return std::make_shared<Pattern>(tensorNum);
11340843347SGus Smith }
11440843347SGus Smith
115*66ae1d60SPeiming Liu #define IMPL_BINOP_PATTERN(OP, KIND) \
116*66ae1d60SPeiming Liu LLVM_ATTRIBUTE_UNUSED static std::shared_ptr<Pattern> OP##Pattern( \
117*66ae1d60SPeiming Liu const std::shared_ptr<Pattern> &e0, \
118*66ae1d60SPeiming Liu const std::shared_ptr<Pattern> &e1) { \
119*66ae1d60SPeiming Liu return std::make_shared<Pattern>(KIND, e0, e1); \
12040843347SGus Smith }
12140843347SGus Smith
122*66ae1d60SPeiming Liu FOREVERY_BINOP(IMPL_BINOP_PATTERN)
123*66ae1d60SPeiming Liu
124*66ae1d60SPeiming Liu #undef IMPL_BINOP_PATTERN
12540843347SGus Smith
12640843347SGus Smith class MergerTestBase : public ::testing::Test {
12740843347SGus Smith protected:
MergerTestBase(unsigned numTensors,unsigned numLoops)12840843347SGus Smith MergerTestBase(unsigned numTensors, unsigned numLoops)
12940843347SGus Smith : numTensors(numTensors), numLoops(numLoops),
13040843347SGus Smith merger(numTensors, numLoops) {}
13140843347SGus Smith
13240843347SGus Smith ///
13340843347SGus Smith /// Expression construction helpers.
13440843347SGus Smith ///
13540843347SGus Smith
tensor(unsigned tensor)13640843347SGus Smith unsigned tensor(unsigned tensor) {
13740843347SGus Smith return merger.addExp(Kind::kTensor, tensor);
13840843347SGus Smith }
13940843347SGus Smith
140*66ae1d60SPeiming Liu #define IMPL_BINOP_EXPR(OP, KIND) \
141*66ae1d60SPeiming Liu LLVM_ATTRIBUTE_UNUSED unsigned OP##Expr(unsigned e0, unsigned e1) { \
142*66ae1d60SPeiming Liu return merger.addExp(KIND, e0, e1); \
14340843347SGus Smith }
14440843347SGus Smith
FOREVERY_BINOP(IMPL_BINOP_EXPR)145*66ae1d60SPeiming Liu FOREVERY_BINOP(IMPL_BINOP_EXPR)
146*66ae1d60SPeiming Liu
147*66ae1d60SPeiming Liu #undef IMPL_BINOP_EXPR
14840843347SGus Smith
14940843347SGus Smith ///
15040843347SGus Smith /// Comparison helpers.
15140843347SGus Smith ///
15240843347SGus Smith
15340843347SGus Smith /// For readability of tests.
15440843347SGus Smith unsigned lat(unsigned lat) { return lat; }
15540843347SGus Smith
15640843347SGus Smith /// Returns true if a lattice point with an expression matching the given
15740843347SGus Smith /// pattern and bits matching the given bits is present in lattice points
15840843347SGus Smith /// [p, p+n) of lattice set s. This is useful for testing partial ordering
15940843347SGus Smith /// constraints between lattice points. We generally know how contiguous
16040843347SGus Smith /// groups of lattice points should be ordered with respect to other groups,
16140843347SGus Smith /// but there is no required ordering within groups.
162*66ae1d60SPeiming Liu /// If simple is true, then compare the lat.simple field instead to test the
163*66ae1d60SPeiming Liu /// result after optimization
latPointWithinRange(unsigned s,unsigned p,unsigned n,const std::shared_ptr<Pattern> & pattern,const BitVector & bits,bool simple)16440843347SGus Smith bool latPointWithinRange(unsigned s, unsigned p, unsigned n,
1651fc096afSMehdi Amini const std::shared_ptr<Pattern> &pattern,
166*66ae1d60SPeiming Liu const BitVector &bits, bool simple) {
16740843347SGus Smith for (unsigned i = p; i < p + n; ++i) {
16840843347SGus Smith if (compareExpression(merger.lat(merger.set(s)[i]).exp, pattern) &&
169*66ae1d60SPeiming Liu compareBits(s, i, bits, simple))
17040843347SGus Smith return true;
17140843347SGus Smith }
17240843347SGus Smith return false;
17340843347SGus Smith }
17440843347SGus Smith
17540843347SGus Smith /// Wrapper over latPointWithinRange for readability of tests.
expectLatPointWithinRange(unsigned s,unsigned p,unsigned n,const std::shared_ptr<Pattern> & pattern,const BitVector & bits,bool simple=false)17640843347SGus Smith void expectLatPointWithinRange(unsigned s, unsigned p, unsigned n,
1774f415216SMehdi Amini const std::shared_ptr<Pattern> &pattern,
178*66ae1d60SPeiming Liu const BitVector &bits, bool simple = false) {
179*66ae1d60SPeiming Liu EXPECT_TRUE(latPointWithinRange(s, p, n, pattern, bits, simple));
18040843347SGus Smith }
18140843347SGus Smith
18240843347SGus Smith /// Wrapper over expectLatPointWithinRange for a single lat point.
expectLatPoint(unsigned s,unsigned p,const std::shared_ptr<Pattern> & pattern,const BitVector & bits,bool simple=false)1834f415216SMehdi Amini void expectLatPoint(unsigned s, unsigned p,
1844f415216SMehdi Amini const std::shared_ptr<Pattern> &pattern,
185*66ae1d60SPeiming Liu const BitVector &bits, bool simple = false) {
186*66ae1d60SPeiming Liu EXPECT_TRUE(latPointWithinRange(s, p, 1, pattern, bits, simple));
18740843347SGus Smith }
18840843347SGus Smith
18940843347SGus Smith /// Converts a vector of (loop, tensor) pairs to a bitvector with the
19040843347SGus Smith /// corresponding bits set.
191d10d49dcSRiver Riddle BitVector
loopsToBits(const std::vector<std::pair<unsigned,unsigned>> & loops)1921fc096afSMehdi Amini loopsToBits(const std::vector<std::pair<unsigned, unsigned>> &loops) {
193d10d49dcSRiver Riddle BitVector testBits = BitVector(numTensors + 1, false);
19440843347SGus Smith for (auto l : loops) {
19540843347SGus Smith auto loop = std::get<0>(l);
19640843347SGus Smith auto tensor = std::get<1>(l);
19740843347SGus Smith testBits.set(numTensors * loop + tensor);
19840843347SGus Smith }
19940843347SGus Smith return testBits;
20040843347SGus Smith }
20140843347SGus Smith
20240843347SGus Smith /// Returns true if the bits of lattice point p in set s match the given bits.
203*66ae1d60SPeiming Liu /// If simple is true, then compare the lat.simple field instead to test the
204*66ae1d60SPeiming Liu /// result after optimization
compareBits(unsigned s,unsigned p,const BitVector & bits,bool simple)205*66ae1d60SPeiming Liu bool compareBits(unsigned s, unsigned p, const BitVector &bits, bool simple) {
206*66ae1d60SPeiming Liu if (simple)
207*66ae1d60SPeiming Liu return merger.lat(merger.set(s)[p]).simple == bits;
20840843347SGus Smith return merger.lat(merger.set(s)[p]).bits == bits;
20940843347SGus Smith }
21040843347SGus Smith
21140843347SGus Smith /// Check that there are n lattice points in set s.
expectNumLatPoints(unsigned s,unsigned n)21240843347SGus Smith void expectNumLatPoints(unsigned s, unsigned n) {
21340843347SGus Smith EXPECT_THAT(merger.set(s).size(), n);
21440843347SGus Smith }
21540843347SGus Smith
21640843347SGus Smith /// Compares expressions for equality. Equality is defined recursively as:
21706aa6ec8SAart Bik /// - Operations are equal if they have the same kind and children.
21806aa6ec8SAart Bik /// - Leaf tensors are equal if they refer to the same tensor.
compareExpression(unsigned e,const std::shared_ptr<Pattern> & pattern)2191fc096afSMehdi Amini bool compareExpression(unsigned e, const std::shared_ptr<Pattern> &pattern) {
22040843347SGus Smith auto tensorExp = merger.exp(e);
22140843347SGus Smith if (tensorExp.kind != pattern->kind)
22240843347SGus Smith return false;
22340843347SGus Smith switch (tensorExp.kind) {
22406aa6ec8SAart Bik // Leaf.
22506aa6ec8SAart Bik case kTensor:
22640843347SGus Smith return tensorExp.tensor == pattern->tensorNum;
22706aa6ec8SAart Bik case kInvariant:
22806aa6ec8SAart Bik case kIndex:
22906aa6ec8SAart Bik llvm_unreachable("invariant not handled yet");
23006aa6ec8SAart Bik // Unary operations.
23106aa6ec8SAart Bik case kAbsF:
23206aa6ec8SAart Bik case kAbsC:
23306aa6ec8SAart Bik case kCeilF:
23406aa6ec8SAart Bik case kFloorF:
23506aa6ec8SAart Bik case kSqrtF:
23606aa6ec8SAart Bik case kSqrtC:
23706aa6ec8SAart Bik case kExpm1F:
23806aa6ec8SAart Bik case kExpm1C:
23906aa6ec8SAart Bik case kLog1pF:
24006aa6ec8SAart Bik case kLog1pC:
24106aa6ec8SAart Bik case kSinF:
24206aa6ec8SAart Bik case kSinC:
24306aa6ec8SAart Bik case kTanhF:
24406aa6ec8SAart Bik case kTanhC:
24506aa6ec8SAart Bik case kNegF:
24606aa6ec8SAart Bik case kNegC:
24706aa6ec8SAart Bik case kNegI:
24806aa6ec8SAart Bik case kTruncF:
24906aa6ec8SAart Bik case kExtF:
25006aa6ec8SAart Bik case kCastFS:
25106aa6ec8SAart Bik case kCastFU:
25206aa6ec8SAart Bik case kCastSF:
25306aa6ec8SAart Bik case kCastUF:
25406aa6ec8SAart Bik case kCastS:
25506aa6ec8SAart Bik case kCastU:
25606aa6ec8SAart Bik case kCastIdx:
25706aa6ec8SAart Bik case kTruncI:
25806aa6ec8SAart Bik case kCIm:
25906aa6ec8SAart Bik case kCRe:
26006aa6ec8SAart Bik case kBitCast:
26106aa6ec8SAart Bik case kBinaryBranch:
26206aa6ec8SAart Bik case kUnary:
26306aa6ec8SAart Bik case kShlI:
26406aa6ec8SAart Bik case kBinary:
265123e8dfcSAart Bik return compareExpression(tensorExp.children.e0, pattern->e0);
26606aa6ec8SAart Bik // Binary operations.
26706aa6ec8SAart Bik case kMulF:
26806aa6ec8SAart Bik case kMulC:
26906aa6ec8SAart Bik case kMulI:
27006aa6ec8SAart Bik case kDivF:
27106aa6ec8SAart Bik case kDivC:
27206aa6ec8SAart Bik case kDivS:
27306aa6ec8SAart Bik case kDivU:
27406aa6ec8SAart Bik case kAddF:
27506aa6ec8SAart Bik case kAddC:
27606aa6ec8SAart Bik case kAddI:
27706aa6ec8SAart Bik case kSubF:
27806aa6ec8SAart Bik case kSubC:
27906aa6ec8SAart Bik case kSubI:
28006aa6ec8SAart Bik case kAndI:
28106aa6ec8SAart Bik case kOrI:
28206aa6ec8SAart Bik case kXorI:
28306aa6ec8SAart Bik case kShrS:
28406aa6ec8SAart Bik case kShrU:
28540843347SGus Smith return compareExpression(tensorExp.children.e0, pattern->e0) &&
28640843347SGus Smith compareExpression(tensorExp.children.e1, pattern->e1);
28740843347SGus Smith }
28806aa6ec8SAart Bik llvm_unreachable("unexpected kind");
28940843347SGus Smith }
29040843347SGus Smith
29140843347SGus Smith unsigned numTensors;
29240843347SGus Smith unsigned numLoops;
29340843347SGus Smith Merger merger;
29440843347SGus Smith };
29540843347SGus Smith
296*66ae1d60SPeiming Liu ///
297*66ae1d60SPeiming Liu /// Tests with all sparse inputs.
298*66ae1d60SPeiming Liu ///
299*66ae1d60SPeiming Liu
30040843347SGus Smith class MergerTest3T1L : public MergerTestBase {
30140843347SGus Smith protected:
30240843347SGus Smith // Our three tensors (two inputs, one output).
30340843347SGus Smith const unsigned t0 = 0, t1 = 1, t2 = 2;
30440843347SGus Smith
30540843347SGus Smith // Our single loop.
30640843347SGus Smith const unsigned l0 = 0;
30740843347SGus Smith
MergerTest3T1L()30840843347SGus Smith MergerTest3T1L() : MergerTestBase(3, 1) {
30940843347SGus Smith // Tensor 0: sparse input vector.
31040843347SGus Smith merger.addExp(Kind::kTensor, t0, -1u);
31140843347SGus Smith merger.setDim(t0, l0, Dim::kSparse);
31240843347SGus Smith
31340843347SGus Smith // Tensor 1: sparse input vector.
31440843347SGus Smith merger.addExp(Kind::kTensor, t1, -1u);
31540843347SGus Smith merger.setDim(t1, l0, Dim::kSparse);
31640843347SGus Smith
31740843347SGus Smith // Tensor 2: dense output vector.
31840843347SGus Smith merger.addExp(Kind::kTensor, t2, -1u);
31940843347SGus Smith merger.setDim(t2, l0, Dim::kDense);
32040843347SGus Smith }
32140843347SGus Smith };
32240843347SGus Smith
323*66ae1d60SPeiming Liu class MergerTest4T1L : public MergerTestBase {
324*66ae1d60SPeiming Liu protected:
325*66ae1d60SPeiming Liu // Our four tensors (three inputs, one output).
326*66ae1d60SPeiming Liu const unsigned t0 = 0, t1 = 1, t2 = 2, t3 = 3;
327*66ae1d60SPeiming Liu
328*66ae1d60SPeiming Liu // Our single loop.
329*66ae1d60SPeiming Liu const unsigned l0 = 0;
330*66ae1d60SPeiming Liu
MergerTest4T1L()331*66ae1d60SPeiming Liu MergerTest4T1L() : MergerTestBase(4, 1) {
332*66ae1d60SPeiming Liu // Tensor 0: sparse input vector.
333*66ae1d60SPeiming Liu merger.addExp(Kind::kTensor, t0, -1u);
334*66ae1d60SPeiming Liu merger.setDim(t0, l0, Dim::kSparse);
335*66ae1d60SPeiming Liu
336*66ae1d60SPeiming Liu // Tensor 1: sparse input vector.
337*66ae1d60SPeiming Liu merger.addExp(Kind::kTensor, t1, -1u);
338*66ae1d60SPeiming Liu merger.setDim(t1, l0, Dim::kSparse);
339*66ae1d60SPeiming Liu
340*66ae1d60SPeiming Liu // Tensor 2: sparse input vector
341*66ae1d60SPeiming Liu merger.addExp(Kind::kTensor, t2, -1u);
342*66ae1d60SPeiming Liu merger.setDim(t2, l0, Dim::kSparse);
343*66ae1d60SPeiming Liu
344*66ae1d60SPeiming Liu // Tensor 3: dense output vector
345*66ae1d60SPeiming Liu merger.addExp(Kind::kTensor, t3, -1u);
346*66ae1d60SPeiming Liu merger.setDim(t3, l0, Dim::kDense);
347*66ae1d60SPeiming Liu }
348*66ae1d60SPeiming Liu };
349*66ae1d60SPeiming Liu
350*66ae1d60SPeiming Liu ///
351*66ae1d60SPeiming Liu /// Tests with both sparse and dense input.
352*66ae1d60SPeiming Liu ///
353*66ae1d60SPeiming Liu
354*66ae1d60SPeiming Liu class MergerTest3T1LD : public MergerTestBase {
355*66ae1d60SPeiming Liu protected:
356*66ae1d60SPeiming Liu // Our three tensors (two inputs, one output).
357*66ae1d60SPeiming Liu const unsigned t0 = 0, t1 = 1, t2 = 2;
358*66ae1d60SPeiming Liu
359*66ae1d60SPeiming Liu // Our single loop.
360*66ae1d60SPeiming Liu const unsigned l0 = 0;
361*66ae1d60SPeiming Liu
MergerTest3T1LD()362*66ae1d60SPeiming Liu MergerTest3T1LD() : MergerTestBase(3, 1) {
363*66ae1d60SPeiming Liu // Tensor 0: sparse input vector.
364*66ae1d60SPeiming Liu merger.addExp(Kind::kTensor, t0, -1u);
365*66ae1d60SPeiming Liu merger.setDim(t0, l0, Dim::kSparse);
366*66ae1d60SPeiming Liu
367*66ae1d60SPeiming Liu // Tensor 1: dense input vector.
368*66ae1d60SPeiming Liu merger.addExp(Kind::kTensor, t1, -1u);
369*66ae1d60SPeiming Liu merger.setDim(t1, l0, Dim::kDense);
370*66ae1d60SPeiming Liu
371*66ae1d60SPeiming Liu // Tensor 2: dense output vector.
372*66ae1d60SPeiming Liu merger.addExp(Kind::kTensor, t2, -1u);
373*66ae1d60SPeiming Liu merger.setDim(t2, l0, Dim::kDense);
374*66ae1d60SPeiming Liu }
375*66ae1d60SPeiming Liu };
376*66ae1d60SPeiming Liu
377be0a7e9fSMehdi Amini } // namespace
37840843347SGus Smith
379*66ae1d60SPeiming Liu /// Vector addition (disjunction) of 2 vectors. i.e.;
38040843347SGus Smith /// a(i) = b(i) + c(i)
38140843347SGus Smith /// which should form the 3 lattice points
38240843347SGus Smith /// {
38340843347SGus Smith /// lat( i_00 i_01 / (tensor_0 + tensor_1) )
38440843347SGus Smith /// lat( i_00 / tensor_0 )
38540843347SGus Smith /// lat( i_01 / tensor_1 )
38640843347SGus Smith /// }
387*66ae1d60SPeiming Liu /// and after optimization, the lattice points do not change (as there is no
388*66ae1d60SPeiming Liu /// duplicated point and all input vectors are sparse vector).
38940843347SGus Smith /// {
39040843347SGus Smith /// lat( i_00 i_01 / (tensor_0 + tensor_1) )
39140843347SGus Smith /// lat( i_00 / tensor_0 )
392*66ae1d60SPeiming Liu /// lat( i_01 / tensor_1 )
39340843347SGus Smith /// }
394*66ae1d60SPeiming Liu #define IMPL_MERGER_TEST_DISJ(OP) \
395*66ae1d60SPeiming Liu TEST_F(MergerTest3T1L, vector_##OP) { \
396*66ae1d60SPeiming Liu auto e = OP##Expr(tensor(t0), tensor(t1)); \
397*66ae1d60SPeiming Liu auto p0 = tensorPattern(t0); \
398*66ae1d60SPeiming Liu auto p1 = tensorPattern(t1); \
399*66ae1d60SPeiming Liu auto s = merger.buildLattices(e, l0); \
400*66ae1d60SPeiming Liu \
401*66ae1d60SPeiming Liu expectNumLatPoints(s, 3); \
402*66ae1d60SPeiming Liu expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
403*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
404*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \
405*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \
406*66ae1d60SPeiming Liu \
407*66ae1d60SPeiming Liu s = merger.optimizeSet(s); \
408*66ae1d60SPeiming Liu expectNumLatPoints(s, 3); \
409*66ae1d60SPeiming Liu expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
410*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}}), true); \
411*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}}), \
412*66ae1d60SPeiming Liu true); \
413*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}}), \
414*66ae1d60SPeiming Liu true); \
41540843347SGus Smith }
41640843347SGus Smith
417*66ae1d60SPeiming Liu FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_DISJ)
418*66ae1d60SPeiming Liu
419*66ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_DISJ
420*66ae1d60SPeiming Liu
421*66ae1d60SPeiming Liu /// Vector multiplication (conjunction) of 2 vectors, i.e.;
42240843347SGus Smith /// a(i) = b(i) * c(i)
42340843347SGus Smith /// which should form the single lattice point
42440843347SGus Smith /// {
42540843347SGus Smith /// lat( i_00 i_01 / (tensor_0 * tensor_1) )
42640843347SGus Smith /// }
427*66ae1d60SPeiming Liu #define IMPL_MERGER_TEST_CONJ(OP) \
428*66ae1d60SPeiming Liu TEST_F(MergerTest3T1L, vector_##OP) { \
429*66ae1d60SPeiming Liu auto e = OP##Expr(t0, t1); \
430*66ae1d60SPeiming Liu auto p0 = tensorPattern(t0); \
431*66ae1d60SPeiming Liu auto p1 = tensorPattern(t1); \
432*66ae1d60SPeiming Liu auto s = merger.buildLattices(e, l0); \
433*66ae1d60SPeiming Liu \
434*66ae1d60SPeiming Liu expectNumLatPoints(s, 1); \
435*66ae1d60SPeiming Liu expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
436*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
437*66ae1d60SPeiming Liu \
438*66ae1d60SPeiming Liu s = merger.optimizeSet(s); \
439*66ae1d60SPeiming Liu expectNumLatPoints(s, 1); \
440*66ae1d60SPeiming Liu expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
441*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}}), true); \
44240843347SGus Smith }
443*66ae1d60SPeiming Liu
444*66ae1d60SPeiming Liu FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_CONJ)
445*66ae1d60SPeiming Liu
446*66ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_CONJ
447*66ae1d60SPeiming Liu
448*66ae1d60SPeiming Liu /// Vector multiplication (conjunction) then addition (disjunction), i.e.;
449*66ae1d60SPeiming Liu /// a(i) = b(i) * c(i) + d(i);
450*66ae1d60SPeiming Liu /// which should form
451*66ae1d60SPeiming Liu /// {
452*66ae1d60SPeiming Liu /// lat( i_00 i_01 i_02 / (tensor_0 * tensor_1) + tensor_2 )
453*66ae1d60SPeiming Liu /// lat( i_00 i_01 / tensor_0 * tensor_1
454*66ae1d60SPeiming Liu /// lat( i_02 / tensor_2 )
455*66ae1d60SPeiming Liu /// }
456*66ae1d60SPeiming Liu #define IMPL_MERGER_TEST_CONJ_DISJ(CONJ, DISJ) \
457*66ae1d60SPeiming Liu TEST_F(MergerTest4T1L, vector_##CONJ##_##DISJ) { \
458*66ae1d60SPeiming Liu auto em = CONJ##Expr(t0, t1); \
459*66ae1d60SPeiming Liu auto e = DISJ##Expr(em, t2); \
460*66ae1d60SPeiming Liu auto p0 = tensorPattern(t0); \
461*66ae1d60SPeiming Liu auto p1 = tensorPattern(t1); \
462*66ae1d60SPeiming Liu auto p2 = tensorPattern(t2); \
463*66ae1d60SPeiming Liu auto s = merger.buildLattices(e, l0); \
464*66ae1d60SPeiming Liu \
465*66ae1d60SPeiming Liu expectNumLatPoints(s, 3); \
466*66ae1d60SPeiming Liu expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
467*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
468*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \
469*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
470*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \
471*66ae1d60SPeiming Liu \
472*66ae1d60SPeiming Liu s = merger.optimizeSet(s); \
473*66ae1d60SPeiming Liu expectNumLatPoints(s, 3); \
474*66ae1d60SPeiming Liu expectLatPoint(s, lat(0), DISJ##Pattern(CONJ##Pattern(p0, p1), p2), \
475*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
476*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 2, CONJ##Pattern(p0, p1), \
477*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
478*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 2, p2, loopsToBits({{l0, t2}})); \
479*66ae1d60SPeiming Liu }
480*66ae1d60SPeiming Liu
481*66ae1d60SPeiming Liu FOREVERY_PAIR_OF_COMMON_CONJ_DISJ_BINOP(IMPL_MERGER_TEST_CONJ_DISJ)
482*66ae1d60SPeiming Liu
483*66ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_CONJ_DISJ
484*66ae1d60SPeiming Liu
485*66ae1d60SPeiming Liu /// Vector addition (disjunction) then addition (disjunction), i.e.;
486*66ae1d60SPeiming Liu /// a(i) = b(i) + c(i) + d(i)
487*66ae1d60SPeiming Liu /// which should form
488*66ae1d60SPeiming Liu /// {
489*66ae1d60SPeiming Liu /// lat( i_00 i_01 i_02 / (tensor_0 + tensor_1) + tensor_2 )
490*66ae1d60SPeiming Liu /// lat( i_02 i_01 / tensor_2 + tensor_1 )
491*66ae1d60SPeiming Liu /// lat( i_02 i_00 / tensor_2 + tensor_0 )
492*66ae1d60SPeiming Liu /// lat( i_01 i_00 / tensor_1 + tensor_0 )
493*66ae1d60SPeiming Liu /// lat( i_02 / tensor_2 )
494*66ae1d60SPeiming Liu /// lat( i_01 / tensor_1 )
495*66ae1d60SPeiming Liu /// lat( i_00 / tensor_0 )
496*66ae1d60SPeiming Liu /// }
497*66ae1d60SPeiming Liu #define IMPL_MERGER_TEST_DISJ_DISJ(DISJ1, DISJ2) \
498*66ae1d60SPeiming Liu TEST_F(MergerTest4T1L, Vector_##DISJ1##_##DISJ2) { \
499*66ae1d60SPeiming Liu auto em = DISJ1##Expr(t0, t1); \
500*66ae1d60SPeiming Liu auto e = DISJ2##Expr(em, t2); \
501*66ae1d60SPeiming Liu auto p0 = tensorPattern(t0); \
502*66ae1d60SPeiming Liu auto p1 = tensorPattern(t1); \
503*66ae1d60SPeiming Liu auto p2 = tensorPattern(t2); \
504*66ae1d60SPeiming Liu auto s = merger.buildLattices(e, l0); \
505*66ae1d60SPeiming Liu \
506*66ae1d60SPeiming Liu expectNumLatPoints(s, 7); \
507*66ae1d60SPeiming Liu expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
508*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
509*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \
510*66ae1d60SPeiming Liu loopsToBits({{l0, t1}, {l0, t2}})); \
511*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \
512*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t2}})); \
513*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \
514*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
515*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \
516*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \
517*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \
518*66ae1d60SPeiming Liu \
519*66ae1d60SPeiming Liu s = merger.optimizeSet(s); \
520*66ae1d60SPeiming Liu expectNumLatPoints(s, 7); \
521*66ae1d60SPeiming Liu expectLatPoint(s, lat(0), DISJ2##Pattern(DISJ1##Pattern(p0, p1), p2), \
522*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
523*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p1, p2), \
524*66ae1d60SPeiming Liu loopsToBits({{l0, t1}, {l0, t2}})); \
525*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 6, DISJ2##Pattern(p0, p2), \
526*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t2}})); \
527*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 6, DISJ1##Pattern(p0, p1), \
528*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
529*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 6, p2, loopsToBits({{l0, t2}})); \
530*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 6, p1, loopsToBits({{l0, t1}})); \
531*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 6, p0, loopsToBits({{l0, t0}})); \
532*66ae1d60SPeiming Liu }
533*66ae1d60SPeiming Liu
534*66ae1d60SPeiming Liu FOREVERY_PAIR_OF_COMMON_DISJ_DISJ_BINOP(IMPL_MERGER_TEST_DISJ_DISJ)
535*66ae1d60SPeiming Liu
536*66ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_DISJ_DISJ
537*66ae1d60SPeiming Liu
538*66ae1d60SPeiming Liu /// Vector multiplication (conjunction) then multiplication (conjunction), i.e.;
539*66ae1d60SPeiming Liu /// a(i) = b(i) * c(i) * d(i);
540*66ae1d60SPeiming Liu /// which should form
541*66ae1d60SPeiming Liu /// {
542*66ae1d60SPeiming Liu /// lat( i_00 i_01 i_02 / tensor_0 * tensor_1 * tensor_2 )
543*66ae1d60SPeiming Liu /// }
544*66ae1d60SPeiming Liu #define IMPL_MERGER_TEST_CONJ_CONJ(CONJ1, CONJ2) \
545*66ae1d60SPeiming Liu TEST_F(MergerTest4T1L, vector_##CONJ1##_##CONJ2) { \
546*66ae1d60SPeiming Liu auto em = CONJ1##Expr(t0, t1); \
547*66ae1d60SPeiming Liu auto e = CONJ2##Expr(em, t2); \
548*66ae1d60SPeiming Liu auto p0 = tensorPattern(t0); \
549*66ae1d60SPeiming Liu auto p1 = tensorPattern(t1); \
550*66ae1d60SPeiming Liu auto p2 = tensorPattern(t2); \
551*66ae1d60SPeiming Liu auto s = merger.buildLattices(e, l0); \
552*66ae1d60SPeiming Liu expectNumLatPoints(s, 1); \
553*66ae1d60SPeiming Liu expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
554*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}})); \
555*66ae1d60SPeiming Liu s = merger.optimizeSet(s); \
556*66ae1d60SPeiming Liu expectNumLatPoints(s, 1); \
557*66ae1d60SPeiming Liu expectLatPoint(s, lat(0), CONJ2##Pattern(CONJ1##Pattern(p0, p1), p2), \
558*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}, {l0, t2}}), true); \
559*66ae1d60SPeiming Liu }
560*66ae1d60SPeiming Liu
561*66ae1d60SPeiming Liu FOREVERY_PAIR_OF_COMMON_CONJ_CONJ_BINOP(IMPL_MERGER_TEST_CONJ_CONJ)
562*66ae1d60SPeiming Liu
563*66ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_CONJ_CONJ
564*66ae1d60SPeiming Liu
565*66ae1d60SPeiming Liu /// Vector addition (disjunction) of 2 vectors, i.e.;
566*66ae1d60SPeiming Liu /// a(i) = b(i) + c(i)
567*66ae1d60SPeiming Liu /// which should form the 3 lattice points
568*66ae1d60SPeiming Liu /// {
569*66ae1d60SPeiming Liu /// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) )
570*66ae1d60SPeiming Liu /// lat( i_00 / sparse_tensor_0 )
571*66ae1d60SPeiming Liu /// lat( i_01 / dense_tensor_1 )
572*66ae1d60SPeiming Liu /// }
573*66ae1d60SPeiming Liu /// which should be optimized to
574*66ae1d60SPeiming Liu /// {
575*66ae1d60SPeiming Liu /// lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ) (not singleton)
576*66ae1d60SPeiming Liu /// lat( i_01 / dense_tensor_0 ) (no sparse dimension)
577*66ae1d60SPeiming Liu /// }
578*66ae1d60SPeiming Liu ///
579*66ae1d60SPeiming Liu /// lat( i_00 / sparse_tensor_0 ) should be opted out as it only has dense diff
580*66ae1d60SPeiming Liu /// with lat( i_00 i_01 / (sparse_tensor_0 + dense_tensor_1) ).
581*66ae1d60SPeiming Liu #define IMPL_MERGER_TEST_OPTIMIZED_DISJ(OP) \
582*66ae1d60SPeiming Liu TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
583*66ae1d60SPeiming Liu auto e = OP##Expr(tensor(t0), tensor(t1)); \
584*66ae1d60SPeiming Liu auto p0 = tensorPattern(t0); \
585*66ae1d60SPeiming Liu auto p1 = tensorPattern(t1); \
586*66ae1d60SPeiming Liu auto s = merger.buildLattices(e, l0); \
587*66ae1d60SPeiming Liu \
588*66ae1d60SPeiming Liu expectNumLatPoints(s, 3); \
589*66ae1d60SPeiming Liu expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
590*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
591*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 2, p0, loopsToBits({{l0, t0}})); \
592*66ae1d60SPeiming Liu expectLatPointWithinRange(s, lat(1), 2, p1, loopsToBits({{l0, t1}})); \
593*66ae1d60SPeiming Liu \
594*66ae1d60SPeiming Liu s = merger.optimizeSet(s); \
595*66ae1d60SPeiming Liu expectNumLatPoints(s, 2); \
596*66ae1d60SPeiming Liu expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
597*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}}), true); \
598*66ae1d60SPeiming Liu expectLatPoint(s, lat(1), p1, loopsToBits({{l0, t1}}), true); \
599*66ae1d60SPeiming Liu }
600*66ae1d60SPeiming Liu
601*66ae1d60SPeiming Liu FOREVERY_COMMON_DISJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_DISJ)
602*66ae1d60SPeiming Liu
603*66ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
604*66ae1d60SPeiming Liu
605*66ae1d60SPeiming Liu /// Vector multiplication (conjunction) of 2 vectors, i.e.:
606*66ae1d60SPeiming Liu /// a(i) = b(i) * c(i)
607*66ae1d60SPeiming Liu /// which should form the single lattice point
608*66ae1d60SPeiming Liu /// {
609*66ae1d60SPeiming Liu /// lat( i_00 i_01 / (sparse_tensor_0 * dense_tensor_1) )
610*66ae1d60SPeiming Liu /// }
611*66ae1d60SPeiming Liu /// it should be optimized to
612*66ae1d60SPeiming Liu /// {
613*66ae1d60SPeiming Liu /// lat( i_00 / (sparse_tensor_0 * dense_tensor_1) )
614*66ae1d60SPeiming Liu /// }
615*66ae1d60SPeiming Liu /// since i_01 is a dense dimension.
616*66ae1d60SPeiming Liu #define IMPL_MERGER_TEST_OPTIMIZED_CONJ(OP) \
617*66ae1d60SPeiming Liu TEST_F(MergerTest3T1LD, vector_opted_##OP) { \
618*66ae1d60SPeiming Liu auto e = OP##Expr(t0, t1); \
619*66ae1d60SPeiming Liu auto p0 = tensorPattern(t0); \
620*66ae1d60SPeiming Liu auto p1 = tensorPattern(t1); \
621*66ae1d60SPeiming Liu auto s = merger.buildLattices(e, l0); \
622*66ae1d60SPeiming Liu \
623*66ae1d60SPeiming Liu expectNumLatPoints(s, 1); \
624*66ae1d60SPeiming Liu expectLatPoint(s, lat(0), OP##Pattern(p0, p1), \
625*66ae1d60SPeiming Liu loopsToBits({{l0, t0}, {l0, t1}})); \
626*66ae1d60SPeiming Liu \
627*66ae1d60SPeiming Liu s = merger.optimizeSet(s); \
628*66ae1d60SPeiming Liu expectNumLatPoints(s, 1); \
629*66ae1d60SPeiming Liu expectLatPoint(s, lat(0), OP##Pattern(p0, p1), loopsToBits({{l0, t0}}), \
630*66ae1d60SPeiming Liu true); \
631*66ae1d60SPeiming Liu }
632*66ae1d60SPeiming Liu
633*66ae1d60SPeiming Liu FOREVERY_COMMON_CONJ_BINOP(IMPL_MERGER_TEST_OPTIMIZED_CONJ)
634*66ae1d60SPeiming Liu
635*66ae1d60SPeiming Liu #undef IMPL_MERGER_TEST_OPTIMIZED_CONJ
636*66ae1d60SPeiming Liu
// TODO: add multi-dimensional (multi-loop) tests.
638*66ae1d60SPeiming Liu
// Restore the MSVC warning state that was pushed at the top of this file.
#if !defined(__clang__) && defined(_MSC_VER)
#pragma warning(pop)
#endif
643