1 //===- PWMAFunctionTest.cpp - Tests for PWMAFunction ----------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains tests for PWMAFunction. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "./Utils.h" 14 15 #include "mlir/Analysis/Presburger/PWMAFunction.h" 16 #include "mlir/Analysis/Presburger/PresburgerRelation.h" 17 #include "mlir/IR/MLIRContext.h" 18 19 #include <gmock/gmock.h> 20 #include <gtest/gtest.h> 21 22 using namespace mlir; 23 using namespace presburger; 24 25 using testing::ElementsAre; 26 27 static Matrix makeMatrix(unsigned numRow, unsigned numColumns, 28 ArrayRef<SmallVector<int64_t, 8>> matrix) { 29 Matrix results(numRow, numColumns); 30 assert(matrix.size() == numRow); 31 for (unsigned i = 0; i < numRow; ++i) { 32 assert(matrix[i].size() == numColumns && 33 "Output expression has incorrect dimensionality!"); 34 for (unsigned j = 0; j < numColumns; ++j) 35 results(i, j) = matrix[i][j]; 36 } 37 return results; 38 } 39 40 /// Construct a PWMAFunction given the dimensionalities and an array describing 41 /// the list of pieces. Each piece is given by a string describing the domain 42 /// and a 2D array that represents the output. 43 static PWMAFunction parsePWMAF( 44 unsigned numInputs, unsigned numOutputs, 45 ArrayRef<std::pair<StringRef, SmallVector<SmallVector<int64_t, 8>, 8>>> 46 data, 47 unsigned numSymbols = 0) { 48 PWMAFunction result(numInputs - numSymbols, numSymbols, numOutputs); 49 for (const auto &pair : data) { 50 IntegerPolyhedron domain = parsePoly(pair.first); 51 result.addPiece( 52 domain, makeMatrix(numOutputs, domain.getNumIds() + 1, pair.second)); 53 } 54 return result; 55 } 56 57 TEST(PWAFunctionTest, isEqual) { 58 // The output expressions are different but it doesn't matter because they are 59 // equal in this domain. 60 PWMAFunction idAtZeros = parsePWMAF( 61 /*numInputs=*/2, /*numOutputs=*/2, 62 { 63 {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y). 64 {"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y). 65 {"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y). 66 }); 67 PWMAFunction idAtZeros2 = parsePWMAF( 68 /*numInputs=*/2, /*numOutputs=*/2, 69 { 70 {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 20, 0}}}, // (x, 20y). 71 {"(x, y) : (y - 1 >= 0, x == 0)", {{30, 0, 0}, {0, 1, 0}}}, //(30x, y) 72 {"(x, y) : (-y - 1 > =0, x == 0)", {{30, 0, 0}, {0, 1, 0}}} //(30x, y) 73 }); 74 EXPECT_TRUE(idAtZeros.isEqual(idAtZeros2)); 75 76 PWMAFunction notIdAtZeros = parsePWMAF( 77 /*numInputs=*/2, /*numOutputs=*/2, 78 { 79 {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y). 80 {"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}}, // (x, 2y) 81 {"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}}, // (x, 2y) 82 }); 83 EXPECT_FALSE(idAtZeros.isEqual(notIdAtZeros)); 84 85 // These match at their intersection but one has a bigger domain. 86 PWMAFunction idNoNegNegQuadrant = parsePWMAF( 87 /*numInputs=*/2, /*numOutputs=*/2, 88 { 89 {"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y). 90 {"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y). 91 }); 92 PWMAFunction idOnlyPosX = 93 parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2, 94 { 95 {"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y). 96 }); 97 EXPECT_FALSE(idNoNegNegQuadrant.isEqual(idOnlyPosX)); 98 99 // Different representations of the same domain. 100 PWMAFunction sumPlusOne = parsePWMAF( 101 /*numInputs=*/2, /*numOutputs=*/1, 102 { 103 {"(x, y) : (x >= 0)", {{1, 1, 1}}}, // x + y + 1. 104 {"(x, y) : (-x - 1 >= 0, -y - 1 >= 0)", {{1, 1, 1}}}, // x + y + 1. 105 {"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 1, 1}}} // x + y + 1. 106 }); 107 PWMAFunction sumPlusOne2 = 108 parsePWMAF(/*numInputs=*/2, /*numOutputs=*/1, 109 { 110 {"(x, y) : ()", {{1, 1, 1}}}, // x + y + 1. 111 }); 112 EXPECT_TRUE(sumPlusOne.isEqual(sumPlusOne2)); 113 114 // Functions with zero input dimensions. 115 PWMAFunction noInputs1 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1, 116 { 117 {"() : ()", {{1}}}, // 1. 118 }); 119 PWMAFunction noInputs2 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1, 120 { 121 {"() : ()", {{2}}}, // 1. 122 }); 123 EXPECT_TRUE(noInputs1.isEqual(noInputs1)); 124 EXPECT_FALSE(noInputs1.isEqual(noInputs2)); 125 126 // Mismatched dimensionalities. 127 EXPECT_FALSE(noInputs1.isEqual(sumPlusOne)); 128 EXPECT_FALSE(idOnlyPosX.isEqual(sumPlusOne)); 129 130 // Divisions. 131 // Domain is only multiples of 6; x = 6k for some k. 132 // x + 4(x/2) + 4(x/3) == 26k. 133 PWMAFunction mul2AndMul3 = parsePWMAF( 134 /*numInputs=*/1, /*numOutputs=*/1, 135 { 136 {"(x) : (x - 2*(x floordiv 2) == 0, x - 3*(x floordiv 3) == 0)", 137 {{1, 4, 4, 0}}}, // x + 4(x/2) + 4(x/3). 138 }); 139 PWMAFunction mul6 = parsePWMAF( 140 /*numInputs=*/1, /*numOutputs=*/1, 141 { 142 {"(x) : (x - 6*(x floordiv 6) == 0)", {{0, 26, 0}}}, // 26(x/6). 143 }); 144 EXPECT_TRUE(mul2AndMul3.isEqual(mul6)); 145 146 PWMAFunction mul6diff = parsePWMAF( 147 /*numInputs=*/1, /*numOutputs=*/1, 148 { 149 {"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 52, 0}}}, // 52(x/6). 150 }); 151 EXPECT_FALSE(mul2AndMul3.isEqual(mul6diff)); 152 153 PWMAFunction mul5 = parsePWMAF( 154 /*numInputs=*/1, /*numOutputs=*/1, 155 { 156 {"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 26, 0}}}, // 26(x/5). 157 }); 158 EXPECT_FALSE(mul2AndMul3.isEqual(mul5)); 159 } 160 161 TEST(PWMAFunction, valueAt) { 162 PWMAFunction nonNegPWAF = parsePWMAF( 163 /*numInputs=*/2, /*numOutputs=*/2, 164 { 165 {"(x, y) : (x >= 0)", {{1, 2, 3}, {3, 4, 5}}}, // (x, y). 166 {"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (x, y) 167 }); 168 EXPECT_THAT(*nonNegPWAF.valueAt({2, 3}), ElementsAre(11, 23)); 169 EXPECT_THAT(*nonNegPWAF.valueAt({-2, 3}), ElementsAre(11, 23)); 170 EXPECT_THAT(*nonNegPWAF.valueAt({2, -3}), ElementsAre(-1, -1)); 171 EXPECT_FALSE(nonNegPWAF.valueAt({-2, -3}).hasValue()); 172 } 173