1 //===- PWMAFunctionTest.cpp - Tests for PWMAFunction ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains tests for PWMAFunction.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "./Utils.h"
14 
15 #include "mlir/Analysis/Presburger/PWMAFunction.h"
16 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
17 #include "mlir/IR/MLIRContext.h"
18 
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 
22 using namespace mlir;
23 using namespace presburger;
24 
25 using testing::ElementsAre;
26 
27 static Matrix makeMatrix(unsigned numRow, unsigned numColumns,
28                          ArrayRef<SmallVector<int64_t, 8>> matrix) {
29   Matrix results(numRow, numColumns);
30   assert(matrix.size() == numRow);
31   for (unsigned i = 0; i < numRow; ++i) {
32     assert(matrix[i].size() == numColumns &&
33            "Output expression has incorrect dimensionality!");
34     for (unsigned j = 0; j < numColumns; ++j)
35       results(i, j) = matrix[i][j];
36   }
37   return results;
38 }
39 
40 /// Construct a PWMAFunction given the dimensionalities and an array describing
41 /// the list of pieces. Each piece is given by a string describing the domain
42 /// and a 2D array that represents the output.
43 static PWMAFunction parsePWMAF(
44     unsigned numInputs, unsigned numOutputs,
45     ArrayRef<std::pair<StringRef, SmallVector<SmallVector<int64_t, 8>, 8>>>
46         data,
47     unsigned numSymbols = 0) {
48   PWMAFunction result(numInputs - numSymbols, numSymbols, numOutputs);
49   for (const auto &pair : data) {
50     IntegerPolyhedron domain = parsePoly(pair.first);
51     result.addPiece(
52         domain, makeMatrix(numOutputs, domain.getNumIds() + 1, pair.second));
53   }
54   return result;
55 }
56 
57 TEST(PWAFunctionTest, isEqual) {
58   // The output expressions are different but it doesn't matter because they are
59   // equal in this domain.
60   PWMAFunction idAtZeros = parsePWMAF(
61       /*numInputs=*/2, /*numOutputs=*/2,
62       {
63           {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}},             // (x, y).
64           {"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
65           {"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y).
66       });
67   PWMAFunction idAtZeros2 = parsePWMAF(
68       /*numInputs=*/2, /*numOutputs=*/2,
69       {
70           {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 20, 0}}}, // (x, 20y).
71           {"(x, y) : (y - 1 >= 0, x == 0)", {{30, 0, 0}, {0, 1, 0}}}, //(30x, y)
72           {"(x, y) : (-y - 1 > =0, x == 0)", {{30, 0, 0}, {0, 1, 0}}} //(30x, y)
73       });
74   EXPECT_TRUE(idAtZeros.isEqual(idAtZeros2));
75 
76   PWMAFunction notIdAtZeros = parsePWMAF(
77       /*numInputs=*/2, /*numOutputs=*/2,
78       {
79           {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}},              // (x, y).
80           {"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}},  // (x, 2y)
81           {"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}}, // (x, 2y)
82       });
83   EXPECT_FALSE(idAtZeros.isEqual(notIdAtZeros));
84 
85   // These match at their intersection but one has a bigger domain.
86   PWMAFunction idNoNegNegQuadrant = parsePWMAF(
87       /*numInputs=*/2, /*numOutputs=*/2,
88       {
89           {"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}},             // (x, y).
90           {"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y).
91       });
92   PWMAFunction idOnlyPosX =
93       parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2,
94                  {
95                      {"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
96                  });
97   EXPECT_FALSE(idNoNegNegQuadrant.isEqual(idOnlyPosX));
98 
99   // Different representations of the same domain.
100   PWMAFunction sumPlusOne = parsePWMAF(
101       /*numInputs=*/2, /*numOutputs=*/1,
102       {
103           {"(x, y) : (x >= 0)", {{1, 1, 1}}},                   // x + y + 1.
104           {"(x, y) : (-x - 1 >= 0, -y - 1 >= 0)", {{1, 1, 1}}}, // x + y + 1.
105           {"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 1, 1}}}       // x + y + 1.
106       });
107   PWMAFunction sumPlusOne2 =
108       parsePWMAF(/*numInputs=*/2, /*numOutputs=*/1,
109                  {
110                      {"(x, y) : ()", {{1, 1, 1}}}, // x + y + 1.
111                  });
112   EXPECT_TRUE(sumPlusOne.isEqual(sumPlusOne2));
113 
114   // Functions with zero input dimensions.
115   PWMAFunction noInputs1 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1,
116                                       {
117                                           {"() : ()", {{1}}}, // 1.
118                                       });
119   PWMAFunction noInputs2 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1,
120                                       {
121                                           {"() : ()", {{2}}}, // 1.
122                                       });
123   EXPECT_TRUE(noInputs1.isEqual(noInputs1));
124   EXPECT_FALSE(noInputs1.isEqual(noInputs2));
125 
126   // Mismatched dimensionalities.
127   EXPECT_FALSE(noInputs1.isEqual(sumPlusOne));
128   EXPECT_FALSE(idOnlyPosX.isEqual(sumPlusOne));
129 
130   // Divisions.
131   // Domain is only multiples of 6; x = 6k for some k.
132   // x + 4(x/2) + 4(x/3) == 26k.
133   PWMAFunction mul2AndMul3 = parsePWMAF(
134       /*numInputs=*/1, /*numOutputs=*/1,
135       {
136           {"(x) : (x - 2*(x floordiv 2) == 0, x - 3*(x floordiv 3) == 0)",
137            {{1, 4, 4, 0}}}, // x + 4(x/2) + 4(x/3).
138       });
139   PWMAFunction mul6 = parsePWMAF(
140       /*numInputs=*/1, /*numOutputs=*/1,
141       {
142           {"(x) : (x - 6*(x floordiv 6) == 0)", {{0, 26, 0}}}, // 26(x/6).
143       });
144   EXPECT_TRUE(mul2AndMul3.isEqual(mul6));
145 
146   PWMAFunction mul6diff = parsePWMAF(
147       /*numInputs=*/1, /*numOutputs=*/1,
148       {
149           {"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 52, 0}}}, // 52(x/6).
150       });
151   EXPECT_FALSE(mul2AndMul3.isEqual(mul6diff));
152 
153   PWMAFunction mul5 = parsePWMAF(
154       /*numInputs=*/1, /*numOutputs=*/1,
155       {
156           {"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 26, 0}}}, // 26(x/5).
157       });
158   EXPECT_FALSE(mul2AndMul3.isEqual(mul5));
159 }
160 
161 TEST(PWMAFunction, valueAt) {
162   PWMAFunction nonNegPWAF = parsePWMAF(
163       /*numInputs=*/2, /*numOutputs=*/2,
164       {
165           {"(x, y) : (x >= 0)", {{1, 2, 3}, {3, 4, 5}}}, // (x, y).
166           {"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (x, y)
167       });
168   EXPECT_THAT(*nonNegPWAF.valueAt({2, 3}), ElementsAre(11, 23));
169   EXPECT_THAT(*nonNegPWAF.valueAt({-2, 3}), ElementsAre(11, 23));
170   EXPECT_THAT(*nonNegPWAF.valueAt({2, -3}), ElementsAre(-1, -1));
171   EXPECT_FALSE(nonNegPWAF.valueAt({-2, -3}).hasValue());
172 }
173