1 //===- PWMAFunctionTest.cpp - Tests for PWMAFunction ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains tests for PWMAFunction.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "./Utils.h"
14
15 #include "mlir/Analysis/Presburger/PWMAFunction.h"
16 #include "mlir/Analysis/Presburger/PresburgerRelation.h"
17 #include "mlir/IR/MLIRContext.h"
18
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21
22 using namespace mlir;
23 using namespace presburger;
24
25 using testing::ElementsAre;
26
// Checks PWMAFunction::isEqual on functions that agree or differ in their
// outputs, their domains, their dimensionality, and their use of divisions.
//
// Every piece handed to parsePWMAF pairs a domain string with one coefficient
// row per output. Going by the per-row comments below, a row lists the
// coefficient of each input, then of each `floordiv` appearing in the domain,
// and finally the constant term.
TEST(PWAFunctionTest, isEqual) {
  // The output expressions are different but it doesn't matter because they are
  // equal in this domain.
  PWMAFunction idAtZeros = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/2,
      {
          {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}},             // (x, y).
          {"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
          {"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y).
      });
  PWMAFunction idAtZeros2 = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/2,
      {
          {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 20, 0}}}, // (x, 20y).
          {"(x, y) : (y - 1 >= 0, x == 0)", {{30, 0, 0}, {0, 1, 0}}}, // (30x, y).
          {"(x, y) : (-y - 1 > =0, x == 0)", {{30, 0, 0}, {0, 1, 0}}} // (30x, y).
      });
  EXPECT_TRUE(idAtZeros.isEqual(idAtZeros2));

  // Same domain as idAtZeros, but the second output differs wherever y != 0.
  PWMAFunction notIdAtZeros = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/2,
      {
          {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}},              // (x, y).
          {"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}},  // (x, 2y).
          {"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}}, // (x, 2y).
      });
  EXPECT_FALSE(idAtZeros.isEqual(notIdAtZeros));

  // These match at their intersection but one has a bigger domain.
  PWMAFunction idNoNegNegQuadrant = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/2,
      {
          {"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}},             // (x, y).
          {"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y).
      });
  PWMAFunction idOnlyPosX = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/2,
      {
          {"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
      });
  EXPECT_FALSE(idNoNegNegQuadrant.isEqual(idOnlyPosX));

  // Different representations of the same domain.
  PWMAFunction sumPlusOne = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/1,
      {
          {"(x, y) : (x >= 0)", {{1, 1, 1}}},                   // x + y + 1.
          {"(x, y) : (-x - 1 >= 0, -y - 1 >= 0)", {{1, 1, 1}}}, // x + y + 1.
          {"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 1, 1}}}       // x + y + 1.
      });
  PWMAFunction sumPlusOne2 = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/1,
      {
          {"(x, y) : ()", {{1, 1, 1}}}, // x + y + 1.
      });
  EXPECT_TRUE(sumPlusOne.isEqual(sumPlusOne2));

  // Functions with zero input dimensions.
  PWMAFunction noInputs1 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1,
                                      {
                                          {"() : ()", {{1}}}, // 1.
                                      });
  PWMAFunction noInputs2 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1,
                                      {
                                          {"() : ()", {{2}}}, // 2.
                                      });
  EXPECT_TRUE(noInputs1.isEqual(noInputs1));
  EXPECT_FALSE(noInputs1.isEqual(noInputs2));

  // Mismatched dimensionalities.
  EXPECT_FALSE(noInputs1.isEqual(sumPlusOne));
  EXPECT_FALSE(idOnlyPosX.isEqual(sumPlusOne));

  // Divisions.
  // Domain is only multiples of 6; x = 6k for some k.
  // x + 4(x/2) + 4(x/3) == 26k.
  PWMAFunction mul2AndMul3 = parsePWMAF(
      /*numInputs=*/1, /*numOutputs=*/1,
      {
          {"(x) : (x - 2*(x floordiv 2) == 0, x - 3*(x floordiv 3) == 0)",
           {{1, 4, 4, 0}}}, // x + 4(x/2) + 4(x/3).
      });
  PWMAFunction mul6 = parsePWMAF(
      /*numInputs=*/1, /*numOutputs=*/1,
      {
          {"(x) : (x - 6*(x floordiv 6) == 0)", {{0, 26, 0}}}, // 26(x/6).
      });
  EXPECT_TRUE(mul2AndMul3.isEqual(mul6));

  // Same domain as mul6 (multiples of 6) but twice the output, so the
  // functions differ at every non-zero point of the shared domain.
  PWMAFunction mul6diff = parsePWMAF(
      /*numInputs=*/1, /*numOutputs=*/1,
      {
          {"(x) : (x - 6*(x floordiv 6) == 0)", {{0, 52, 0}}}, // 52(x/6).
      });
  EXPECT_FALSE(mul2AndMul3.isEqual(mul6diff));

  // Defined over the multiples of 5 instead, so the domains do not match.
  PWMAFunction mul5 = parsePWMAF(
      /*numInputs=*/1, /*numOutputs=*/1,
      {
          {"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 26, 0}}}, // 26(x/5).
      });
  EXPECT_FALSE(mul2AndMul3.isEqual(mul5));
}
130
// Evaluates piecewise functions at individual points. The assertions expect
// valueAt to produce the outputs of the piece whose domain contains the point
// and an empty optional when no piece contains it.
TEST(PWMAFunction, valueAt) {
  // Two affine pieces; the only uncovered region is x < 0 together with y < 0.
  PWMAFunction halfPlanes = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/2,
      {
          // (x + 2y + 3, 3x + 4y + 5).
          {"(x, y) : (x >= 0)", {{1, 2, 3}, {3, 4, 5}}},
          // (-x + 2y + 3, -3x + 4y + 5).
          {"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}},
      });
  // Points inside the first piece, the second piece, and the first piece again.
  EXPECT_THAT(*halfPlanes.valueAt({2, 3}), ElementsAre(11, 23));
  EXPECT_THAT(*halfPlanes.valueAt({-2, 3}), ElementsAre(11, 23));
  EXPECT_THAT(*halfPlanes.valueAt({2, -3}), ElementsAre(-1, -1));
  // Outside both pieces.
  EXPECT_FALSE(halfPlanes.valueAt({-2, -3}).has_value());

  // Same as above, except the first piece is restricted to even x and its
  // outputs use the division q = x floordiv 2 introduced by that constraint.
  PWMAFunction withDivision = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/2,
      {
          // (2y + q + 3, 4y + 3q + 5).
          {"(x, y) : (x >= 0, x - 2*(x floordiv 2) == 0)",
           {{0, 2, 1, 3}, {0, 4, 3, 5}}},
          // (-x + 2y + 3, -3x + 4y + 5).
          {"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}},
      });
  // Even, non-negative x lands in the first piece.
  EXPECT_THAT(*withDivision.valueAt({4, 3}), ElementsAre(11, 23));
  EXPECT_THAT(*withDivision.valueAt({4, -3}), ElementsAre(-1, -1));
  // Odd, non-negative x belongs to neither piece.
  EXPECT_FALSE(withDivision.valueAt({3, 3}).has_value());
  EXPECT_FALSE(withDivision.valueAt({3, -3}).has_value());

  // Negative x is covered only when y >= 0.
  EXPECT_THAT(*withDivision.valueAt({-2, 3}), ElementsAre(11, 23));
  EXPECT_FALSE(withDivision.valueAt({-2, -3}).has_value());
}
158
// Regression test: two functions that are identically zero on domains which
// both reduce to the single point (0, 0), each domain carrying two division
// constraints, must compare equal.
TEST(PWMAFunction, removeIdRangeRegressionTest) {
  // Domain pins x and y to zero directly.
  PWMAFunction zeroAtOrigin = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/1,
      {
          {"(x, y) : (x == 0, y == 0, x - 2*(x floordiv 2) == 0, "
           "y - 2*(y floordiv 2) == 0)",
           {{0, 0, 0, 0, 0}}}, // Constant 0.
      });
  // x == 11y and y == 11x only hold simultaneously at x == y == 0.
  PWMAFunction zeroOnCrossing = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/1,
      {
          {"(x, y) : (x - 11*y == 0, 11*x - y == 0, x - 2*(x floordiv 2) == 0, "
           "y - 2*(y floordiv 2) == 0)",
           {{0, 0, 0, 0, 0}}}, // Constant 0.
      });
  EXPECT_TRUE(zeroAtOrigin.isEqual(zeroOnCrossing));
}
176
// Regression test: on a domain where x == 2y, the outputs y and x - y denote
// the same value, so the two functions must compare equal even though the
// domain also carries a division constraint on x.
TEST(PWMAFunction, eliminateRedundantLocalIdRegressionTest) {
  PWMAFunction outputsY = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/1,
      {
          {"(x, y) : (x - 2*(x floordiv 2) == 0, x - 2*y == 0)",
           {{0, 1, 0, 0}}}, // y.
      });
  PWMAFunction outputsXMinusY = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/1,
      {
          {"(x, y) : (x - 2*(x floordiv 2) == 0, x - 2*y == 0)",
           {{1, -1, 0, 0}}}, // x - y.
      });
  EXPECT_TRUE(outputsY.isEqual(outputsXMinusY));
}
192
// Basic single-input, single-output cases for PWMAFunction::unionLexMax. Each
// case builds two operands and the expected union, then checks the union in
// both operand orders. The expected functions take the larger output wherever
// both operands are defined and the only available output elsewhere.
TEST(PWMAFunction, unionLexMaxSimple) {
  // func2 is better than func1, but func2's domain is empty.
  {
    PWMAFunction func1 = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : ()", {{0, 1}}}, // 1.
        });

    PWMAFunction func2 = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : (1 == 0)", {{0, 2}}}, // 2, on an empty domain.
        });

    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(func1));
    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(func1));
  }

  // func2 is better than func1 on a subset of func1.
  {
    PWMAFunction func1 = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : ()", {{0, 1}}}, // 1.
        });

    PWMAFunction func2 = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 2}}}, // 2.
        });

    PWMAFunction result = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : (-1 - x >= 0)", {{0, 1}}},
            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 2}}},
            {"(x) : (x - 11 >= 0)", {{0, 1}}},
        });

    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
  }

  // func1 and func2 are defined over the whole domain with different outputs.
  {
    PWMAFunction func1 = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : ()", {{1, 0}}}, // x.
        });

    PWMAFunction func2 = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : ()", {{-1, 0}}}, // -x.
        });

    PWMAFunction result = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : (x >= 0)", {{1, 0}}},
            {"(x) : (-1 - x >= 0)", {{-1, 0}}},
        });

    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
  }

  // func1 and func2 have disjoint domains.
  {
    PWMAFunction func1 = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 1}}},
            {"(x) : (x - 71 >= 0, 80 - x >= 0)", {{0, 1}}},
        });

    PWMAFunction func2 = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : (x - 20 >= 0, 41 - x >= 0)", {{0, 2}}},
            {"(x) : (x - 101 >= 0, 120 - x >= 0)", {{0, 2}}},
        });

    PWMAFunction result = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : (x >= 0, 10 - x >= 0)", {{0, 1}}},
            {"(x) : (x - 71 >= 0, 80 - x >= 0)", {{0, 1}}},
            {"(x) : (x - 20 >= 0, 41 - x >= 0)", {{0, 2}}},
            {"(x) : (x - 101 >= 0, 120 - x >= 0)", {{0, 2}}},
        });

    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));

    // No point belongs to both operands, so there is nothing to order and the
    // lexicographic minimum must yield the very same union.
    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
    EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
  }
}
292
// Basic single-input, single-output cases for PWMAFunction::unionLexMin. Every
// case checks the union in both operand orders against the same expectation.
TEST(PWMAFunction, unionLexMinSimple) {
  // The right operand has the smaller output, but its domain is empty.
  {
    PWMAFunction everywhere = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : ()", {{0, -1}}}, // -1.
        });

    PWMAFunction nowhere = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : (1 == 0)", {{0, -2}}}, // -2, on an empty domain.
        });

    EXPECT_TRUE(everywhere.unionLexMin(nowhere).isEqual(everywhere));
    EXPECT_TRUE(nowhere.unionLexMin(everywhere).isEqual(everywhere));
  }

  // The second operand has the smaller output on a subset of the first's
  // domain.
  {
    PWMAFunction everywhere = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : ()", {{0, -1}}}, // -1.
        });

    PWMAFunction onInterval = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : (x >= 0, 10 - x >= 0)", {{0, -2}}}, // -2.
        });

    PWMAFunction expected = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : (-1 - x >= 0)", {{0, -1}}},
            {"(x) : (x >= 0, 10 - x >= 0)", {{0, -2}}},
            {"(x) : (x - 11 >= 0)", {{0, -1}}},
        });

    EXPECT_TRUE(everywhere.unionLexMin(onInterval).isEqual(expected));
    EXPECT_TRUE(onInterval.unionLexMin(everywhere).isEqual(expected));
  }

  // Both operands are defined over the whole domain with different outputs.
  {
    PWMAFunction negated = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : ()", {{-1, 0}}}, // -x.
        });

    PWMAFunction identity = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : ()", {{1, 0}}}, // x.
        });

    PWMAFunction expected = parsePWMAF(
        /*numInputs=*/1, /*numOutputs=*/1,
        {
            {"(x) : (x >= 0)", {{-1, 0}}},
            {"(x) : (-1 - x >= 0)", {{1, 0}}},
        });

    EXPECT_TRUE(negated.unionLexMin(identity).isEqual(expected));
    EXPECT_TRUE(identity.unionLexMin(negated).isEqual(expected));
  }
}
363
// Multi-input and multi-output cases for PWMAFunction::unionLexMax, where the
// winning operand depends on the inputs and, for several outputs, on the first
// output in which the operands differ. Every case checks both operand orders.
TEST(PWMAFunction, unionLexMaxComplex) {
  // Union of function containing 4 different pieces of output.
  //
  // x >= 21                --> func1 (func2 not defined)
  // x <= 9                 --> func2 (func1 not defined)
  // 10 <= x <= 20, y > 0   --> func1 (x + y > x - y for y > 0)
  // 10 <= x <= 20, y <= 0  --> func2 (x + y <= x - y for y <= 0)
  {
    PWMAFunction func1 = parsePWMAF(
        /*numInputs=*/2, /*numOutputs=*/1,
        {
            {"(x, y) : (x >= 10)", {{1, 1, 0}}}, // x + y.
        });

    PWMAFunction func2 = parsePWMAF(
        /*numInputs=*/2, /*numOutputs=*/1,
        {
            {"(x, y) : (x <= 20)", {{1, -1, 0}}}, // x - y.
        });

    PWMAFunction result = parsePWMAF(
        /*numInputs=*/2, /*numOutputs=*/1,
        {
            {"(x, y) : (x >= 10, x <= 20, y >= 1)", {{1, 1, 0}}},
            {"(x, y) : (x >= 21)", {{1, 1, 0}}},
            {"(x, y) : (x <= 9)", {{1, -1, 0}}},
            {"(x, y) : (x >= 10, x <= 20, y <= 0)", {{1, -1, 0}}},
        });

    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
    // At y == 0 both operands produce x, so the expected function does not
    // depend on which operand the union is taken from.
    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
  }

  // Functions with more than one output, with contribution from both functions.
  //
  // If y >= 1, func1 is better because in the first output,
  // x + y (func1) > x (func2), when y >= 1
  //
  // If y == 0, the first output is same for both functions, so we look at the
  // second output. -2x + 4 (func1) > 2x - 2 (func2) when 0 <= x <= 1, so we
  // take func1 for this domain and func2 for the remaining.
  {
    PWMAFunction func1 = parsePWMAF(
        /*numInputs=*/2, /*numOutputs=*/2,
        {
            {"(x, y) : (x >= 0, y >= 0)", {{1, 1, 0}, {-2, 0, 4}}},
        });

    PWMAFunction func2 = parsePWMAF(
        /*numInputs=*/2, /*numOutputs=*/2,
        {
            {"(x, y) : (x >= 0, y >= 0)", {{1, 0, 0}, {2, 0, -2}}},
        });

    PWMAFunction result = parsePWMAF(
        /*numInputs=*/2, /*numOutputs=*/2,
        {
            {"(x, y) : (x >= 0, y >= 1)", {{1, 1, 0}, {-2, 0, 4}}},
            {"(x, y) : (x >= 0, x <= 1, y == 0)", {{1, 1, 0}, {-2, 0, 4}}},
            {"(x, y) : (x >= 2, y == 0)", {{1, 0, 0}, {2, 0, -2}}},
        });

    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
  }

  // Function with three boolean variables `a, b, c` used to control which
  // output will be taken lexicographically. func1 outputs (0, b, 0) and func2
  // outputs (a, 0, c).
  //
  // a == 1                 --> Take func2
  // a == 0, b == 1         --> Take func1
  // a == 0, b == 0, c == 1 --> Take func2
  {
    PWMAFunction func1 = parsePWMAF(
        /*numInputs=*/3, /*numOutputs=*/3,
        {
            {"(a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, "
             "c >= 0, 1 - c >= 0)",
             {{0, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 0, 0}}},
        });

    PWMAFunction func2 = parsePWMAF(
        /*numInputs=*/3, /*numOutputs=*/3,
        {
            {"(a, b, c) : (a >= 0, 1 - a >= 0, b >= 0, 1 - b >= 0, "
             "c >= 0, 1 - c >= 0)",
             {{1, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 1, 0}}},
        });

    PWMAFunction result = parsePWMAF(
        /*numInputs=*/3, /*numOutputs=*/3,
        {
            {"(a, b, c) : (a - 1 == 0, b >= 0, 1 - b >= 0, c >= 0, 1 - c >= 0)",
             {{1, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 1, 0}}},
            {"(a, b, c) : (a == 0, b - 1 == 0, c >= 0, 1 - c >= 0)",
             {{0, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 0, 0}}},
            {"(a, b, c) : (a == 0, b == 0, c >= 0, 1 - c >= 0)",
             {{1, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 1, 0}}},
        });

    EXPECT_TRUE(func1.unionLexMax(func2).isEqual(result));
    EXPECT_TRUE(func2.unionLexMax(func1).isEqual(result));
  }
}
485
// Regression test checking that the lexicographic tiebreak in unionLexMin
// produces disjoint domains.
//
// Both operands live on the unit square 0 <= x, y <= 1:
//   negXThenY    outputs (-x, y),
//   zeroThenYm1  outputs (0, y - 1).
//
// At x == 1 the first output decides: -x is < 0, so negXThenY wins.
// At x == 0 the first outputs tie and the second output decides: y - 1 is < y,
// so zeroThenYm1 wins.
TEST(PWMAFunction, unionLexMinComplex) {
  PWMAFunction negXThenY = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/2,
      {
          {"(x, y) : (x >= 0, x <= 1, y >= 0, y <= 1)",
           {{-1, 0, 0}, {0, 1, 0}}},
      });

  PWMAFunction zeroThenYm1 = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/2,
      {
          {"(x, y) : (x >= 0, x <= 1, y >= 0, y <= 1)",
           {{0, 0, 0}, {0, 1, -1}}},
      });

  PWMAFunction expected = parsePWMAF(
      /*numInputs=*/2, /*numOutputs=*/2,
      {
          {"(x, y) : (x == 1, y >= 0, y <= 1)", {{-1, 0, 0}, {0, 1, 0}}},
          {"(x, y) : (x == 0, y >= 0, y <= 1)", {{0, 0, 0}, {0, 1, -1}}},
      });

  // The union must not depend on which operand it is taken from.
  EXPECT_TRUE(negXThenY.unionLexMin(zeroThenYm1).isEqual(expected));
  EXPECT_TRUE(zeroThenYm1.unionLexMin(negXThenY).isEqual(expected));
}
520