1 //===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Analysis/Presburger/PWMAFunction.h" 10 #include "mlir/Analysis/Presburger/Simplex.h" 11 12 using namespace mlir; 13 using namespace presburger; 14 15 // Return the result of subtracting the two given vectors pointwise. 16 // The vectors must be of the same size. 17 // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5]. 18 static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA, 19 ArrayRef<int64_t> vecB) { 20 assert(vecA.size() == vecB.size() && 21 "Cannot subtract vectors of differing lengths!"); 22 SmallVector<int64_t, 8> result; 23 result.reserve(vecA.size()); 24 for (unsigned i = 0, e = vecA.size(); i < e; ++i) 25 result.push_back(vecA[i] - vecB[i]); 26 return result; 27 } 28 29 PresburgerSet PWMAFunction::getDomain() const { 30 PresburgerSet domain = PresburgerSet::getEmpty(getSpace()); 31 for (const MultiAffineFunction &piece : pieces) 32 domain.unionInPlace(piece.getDomain()); 33 return domain; 34 } 35 36 Optional<SmallVector<int64_t, 8>> 37 MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const { 38 assert(point.size() == domainSet.getNumDimAndSymbolVars() && 39 "Point has incorrect dimensionality!"); 40 41 Optional<SmallVector<int64_t, 8>> maybeLocalValues = 42 getDomain().containsPointNoLocal(point); 43 if (!maybeLocalValues) 44 return {}; 45 46 // The point lies in the domain, so we need to compute the output value. 47 SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)}; 48 // The given point didn't include the values of locals which the output is a 49 // function of; we have computed one possible set of values and use them 50 // here. The function is not allowed to have local vars that take more than 51 // one possible value. 52 pointHomogenous.append(*maybeLocalValues); 53 // The matrix `output` has an affine expression in the ith row, corresponding 54 // to the expression for the ith value in the output vector. The last column 55 // of the matrix contains the constant term. Let v be the input point with 56 // a 1 appended at the end. We can see that output * v gives the desired 57 // output vector. 58 pointHomogenous.emplace_back(1); 59 SmallVector<int64_t, 8> result = 60 output.postMultiplyWithColumn(pointHomogenous); 61 assert(result.size() == getNumOutputs()); 62 return result; 63 } 64 65 Optional<SmallVector<int64_t, 8>> 66 PWMAFunction::valueAt(ArrayRef<int64_t> point) const { 67 assert(point.size() == getNumInputs() && 68 "Point has incorrect dimensionality!"); 69 for (const MultiAffineFunction &piece : pieces) 70 if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point)) 71 return output; 72 return {}; 73 } 74 75 void MultiAffineFunction::print(raw_ostream &os) const { 76 os << "Domain:"; 77 domainSet.print(os); 78 os << "Output:\n"; 79 output.print(os); 80 os << "\n"; 81 } 82 83 void MultiAffineFunction::dump() const { print(llvm::errs()); } 84 85 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const { 86 return getDomainSpace().isCompatible(other.getDomainSpace()) && 87 getDomain().isEqual(other.getDomain()) && 88 isEqualWhereDomainsOverlap(other); 89 } 90 91 unsigned MultiAffineFunction::insertVar(VarKind kind, unsigned pos, 92 unsigned num) { 93 assert(kind != VarKind::Domain && "Domain has to be zero in a set"); 94 unsigned absolutePos = domainSet.getVarKindOffset(kind) + pos; 95 output.insertColumns(absolutePos, num); 96 return domainSet.insertVar(kind, pos, num); 97 } 98 99 void MultiAffineFunction::removeVarRange(VarKind kind, unsigned varStart, 100 unsigned varLimit) { 101 output.removeColumns(varStart + domainSet.getVarKindOffset(kind), 102 varLimit - varStart); 103 domainSet.removeVarRange(kind, varStart, varLimit); 104 } 105 106 void MultiAffineFunction::truncateOutput(unsigned count) { 107 assert(count <= output.getNumRows()); 108 output.resizeVertically(count); 109 } 110 111 void PWMAFunction::truncateOutput(unsigned count) { 112 assert(count <= numOutputs); 113 for (MultiAffineFunction &piece : pieces) 114 piece.truncateOutput(count); 115 numOutputs = count; 116 } 117 118 void MultiAffineFunction::mergeLocalVars(MultiAffineFunction &other) { 119 // Merge output local vars of both functions without using division 120 // information i.e. append local vars of `other` to `this` and insert 121 // local vars of `this` to `other` at the start of it's local vars. 122 output.insertColumns(domainSet.getVarKindEnd(VarKind::Local), 123 other.domainSet.getNumLocalVars()); 124 other.output.insertColumns(other.domainSet.getVarKindOffset(VarKind::Local), 125 domainSet.getNumLocalVars()); 126 127 auto merge = [this, &other](unsigned i, unsigned j) -> bool { 128 // Merge local at position j into local at position i in function domain. 129 domainSet.eliminateRedundantLocalVar(i, j); 130 other.domainSet.eliminateRedundantLocalVar(i, j); 131 132 unsigned localOffset = domainSet.getVarKindOffset(VarKind::Local); 133 134 // Merge local at position j into local at position i in output domain. 135 output.addToColumn(localOffset + j, localOffset + i, 1); 136 output.removeColumn(localOffset + j); 137 other.output.addToColumn(localOffset + j, localOffset + i, 1); 138 other.output.removeColumn(localOffset + j); 139 140 return true; 141 }; 142 143 presburger::mergeLocalVars(domainSet, other.domainSet, merge); 144 } 145 146 bool MultiAffineFunction::isEqualWhereDomainsOverlap( 147 MultiAffineFunction other) const { 148 if (!getDomainSpace().isCompatible(other.getDomainSpace())) 149 return false; 150 151 // `commonFunc` has the same output as `this`. 152 MultiAffineFunction commonFunc = *this; 153 // After this merge, `commonFunc` and `other` have the same local vars; they 154 // are merged. 155 commonFunc.mergeLocalVars(other); 156 // After this, the domain of `commonFunc` will be the intersection of the 157 // domains of `this` and `other`. 158 commonFunc.domainSet.append(other.domainSet); 159 160 // `commonDomainMatching` contains the subset of the common domain 161 // where the outputs of `this` and `other` match. 162 // 163 // We want to add constraints equating the outputs of `this` and `other`. 164 // However, `this` may have difference local vars from `other`, whereas we 165 // need both to have the same locals. Accordingly, we use `commonFunc.output` 166 // in place of `this->output`, since `commonFunc` has the same output but also 167 // has its locals merged. 168 IntegerPolyhedron commonDomainMatching = commonFunc.getDomain(); 169 for (unsigned row = 0, e = getNumOutputs(); row < e; ++row) 170 commonDomainMatching.addEquality( 171 subtract(commonFunc.output.getRow(row), other.output.getRow(row))); 172 173 // If the whole common domain is a subset of commonDomainMatching, then they 174 // are equal and the two functions match on the whole common domain. 175 return commonFunc.getDomain().isSubsetOf(commonDomainMatching); 176 } 177 178 /// Two PWMAFunctions are equal if they have the same dimensionalities, 179 /// the same domain, and take the same value at every point in the domain. 180 bool PWMAFunction::isEqual(const PWMAFunction &other) const { 181 if (!space.isCompatible(other.space)) 182 return false; 183 184 if (!this->getDomain().isEqual(other.getDomain())) 185 return false; 186 187 // Check if, whenever the domains of a piece of `this` and a piece of `other` 188 // overlap, they take the same output value. If `this` and `other` have the 189 // same domain (checked above), then this check passes iff the two functions 190 // have the same output at every point in the domain. 191 for (const MultiAffineFunction &aPiece : this->pieces) 192 for (const MultiAffineFunction &bPiece : other.pieces) 193 if (!aPiece.isEqualWhereDomainsOverlap(bPiece)) 194 return false; 195 return true; 196 } 197 198 void PWMAFunction::addPiece(const MultiAffineFunction &piece) { 199 assert(space.isCompatible(piece.getDomainSpace()) && 200 "Piece to be added is not compatible with this PWMAFunction!"); 201 assert(piece.isConsistent() && "Piece is internally inconsistent!"); 202 assert(this->getDomain() 203 .intersect(PresburgerSet(piece.getDomain())) 204 .isIntegerEmpty() && 205 "New piece's domain overlaps with that of existing pieces!"); 206 pieces.push_back(piece); 207 } 208 209 void PWMAFunction::addPiece(const IntegerPolyhedron &domain, 210 const Matrix &output) { 211 addPiece(MultiAffineFunction(domain, output)); 212 } 213 214 void PWMAFunction::addPiece(const PresburgerSet &domain, const Matrix &output) { 215 for (const IntegerRelation &newDom : domain.getAllDisjuncts()) 216 addPiece(IntegerPolyhedron(newDom), output); 217 } 218 219 void PWMAFunction::print(raw_ostream &os) const { 220 os << pieces.size() << " pieces:\n"; 221 for (const MultiAffineFunction &piece : pieces) 222 piece.print(os); 223 } 224 225 void PWMAFunction::dump() const { print(llvm::errs()); } 226 227 PWMAFunction PWMAFunction::unionFunction( 228 const PWMAFunction &func, 229 llvm::function_ref<PresburgerSet(MultiAffineFunction maf1, 230 MultiAffineFunction maf2)> 231 tiebreak) const { 232 assert(getNumOutputs() == func.getNumOutputs() && 233 "Number of outputs of functions should be same."); 234 assert(getSpace().isCompatible(func.getSpace()) && 235 "Space is not compatible."); 236 237 // The algorithm used here is as follows: 238 // - Add the output of funcB for the part of the domain where both funcA and 239 // funcB are defined, and `tiebreak` chooses the output of funcB. 240 // - Add the output of funcA, where funcB is not defined or `tiebreak` chooses 241 // funcA over funcB. 242 // - Add the output of funcB, where funcA is not defined. 243 244 // Add parts of the common domain where funcB's output is used. Also 245 // add all the parts where funcA's output is used, both common and non-common. 246 PWMAFunction result(getSpace(), getNumOutputs()); 247 for (const MultiAffineFunction &funcA : pieces) { 248 PresburgerSet dom(funcA.getDomain()); 249 for (const MultiAffineFunction &funcB : func.pieces) { 250 PresburgerSet better = tiebreak(funcB, funcA); 251 // Add the output of funcB, where it is better than output of funcA. 252 // The disjuncts in "better" will be disjoint as tiebreak should gurantee 253 // that. 254 result.addPiece(better, funcB.getOutputMatrix()); 255 dom = dom.subtract(better); 256 } 257 // Add output of funcA, where it is better than funcB, or funcB is not 258 // defined. 259 // 260 // `dom` here is guranteed to be disjoint from already added pieces 261 // because because the pieces added before are either: 262 // - Subsets of the domain of other MAFs in `this`, which are guranteed 263 // to be disjoint from `dom`, or 264 // - They are one of the pieces added for `funcB`, and we have been 265 // subtracting all such pieces from `dom`, so `dom` is disjoint from those 266 // pieces as well. 267 result.addPiece(dom, funcA.getOutputMatrix()); 268 } 269 270 // Add parts of funcB which are not shared with funcA. 271 PresburgerSet dom = getDomain(); 272 for (const MultiAffineFunction &funcB : func.pieces) 273 result.addPiece(funcB.getDomain().subtract(dom), funcB.getOutputMatrix()); 274 275 return result; 276 } 277 278 /// A tiebreak function which breaks ties by comparing the outputs 279 /// lexicographically. If `lexMin` is true, then the ties are broken by 280 /// taking the lexicographically smaller output and otherwise, by taking the 281 /// lexicographically larger output. 282 template <bool lexMin> 283 static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA, 284 const MultiAffineFunction &mafB) { 285 // TODO: Support local variables here. 286 assert(mafA.getDomainSpace().isCompatible(mafB.getDomainSpace()) && 287 "Domain spaces should be compatible."); 288 assert(mafA.getNumOutputs() == mafB.getNumOutputs() && 289 "Number of outputs of both functions should be same."); 290 assert(mafA.getDomain().getNumLocalVars() == 0 && 291 "Local variables are not supported yet."); 292 293 PresburgerSpace compatibleSpace = mafA.getDomain().getSpaceWithoutLocals(); 294 const PresburgerSpace &space = mafA.getDomain().getSpace(); 295 296 // We first create the set `result`, corresponding to the set where output 297 // of mafA is lexicographically larger/smaller than mafB. This is done by 298 // creating a PresburgerSet with the following constraints: 299 // 300 // (outA[0] > outB[0]) U 301 // (outA[0] = outB[0], outA[1] > outA[1]) U 302 // (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U 303 // ... 304 // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1]) 305 // 306 // where `n` is the number of outputs. 307 // If `lexMin` is set, the complement inequality is used: 308 // 309 // (outA[0] < outB[0]) U 310 // (outA[0] = outB[0], outA[1] < outA[1]) U 311 // (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U 312 // ... 313 // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1]) 314 PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace); 315 IntegerPolyhedron levelSet(/*numReservedInequalities=*/1, 316 /*numReservedEqualities=*/mafA.getNumOutputs(), 317 /*numReservedCols=*/space.getNumVars() + 1, space); 318 for (unsigned level = 0; level < mafA.getNumOutputs(); ++level) { 319 320 // Create the expression `outA - outB` for this level. 321 SmallVector<int64_t, 8> subExpr = 322 subtract(mafA.getOutputExpr(level), mafB.getOutputExpr(level)); 323 324 if (lexMin) { 325 // For lexMin, we add an upper bound of -1: 326 // outA - outB <= -1 327 // outA <= outB - 1 328 // outA < outB 329 levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, -1); 330 } else { 331 // For lexMax, we add a lower bound of 1: 332 // outA - outB >= 1 333 // outA > outB + 1 334 // outA > outB 335 levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, 1); 336 } 337 338 // Union the set with the result. 339 result.unionInPlace(levelSet); 340 // There is only 1 inequality in `levelSet`, so the index is always 0. 341 levelSet.removeInequality(0); 342 // Add equality `outA - outB == 0` for this level for next iteration. 343 levelSet.addEquality(subExpr); 344 } 345 346 // We then intersect `result` with the domain of mafA and mafB, to only 347 // tiebreak on the domain where both are defined. 348 result = result.intersect(PresburgerSet(mafA.getDomain())) 349 .intersect(PresburgerSet(mafB.getDomain())); 350 351 return result; 352 } 353 354 PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) { 355 return unionFunction(func, tiebreakLex</*lexMin=*/true>); 356 } 357 358 PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) { 359 return unionFunction(func, tiebreakLex</*lexMin=*/false>); 360 } 361