1 //===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Analysis/Presburger/PWMAFunction.h"
10 #include "mlir/Analysis/Presburger/Simplex.h"
11
12 using namespace mlir;
13 using namespace presburger;
14
15 // Return the result of subtracting the two given vectors pointwise.
16 // The vectors must be of the same size.
17 // e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
subtract(ArrayRef<int64_t> vecA,ArrayRef<int64_t> vecB)18 static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
19 ArrayRef<int64_t> vecB) {
20 assert(vecA.size() == vecB.size() &&
21 "Cannot subtract vectors of differing lengths!");
22 SmallVector<int64_t, 8> result;
23 result.reserve(vecA.size());
24 for (unsigned i = 0, e = vecA.size(); i < e; ++i)
25 result.push_back(vecA[i] - vecB[i]);
26 return result;
27 }
28
getDomain() const29 PresburgerSet PWMAFunction::getDomain() const {
30 PresburgerSet domain = PresburgerSet::getEmpty(getSpace());
31 for (const MultiAffineFunction &piece : pieces)
32 domain.unionInPlace(piece.getDomain());
33 return domain;
34 }
35
36 Optional<SmallVector<int64_t, 8>>
valueAt(ArrayRef<int64_t> point) const37 MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
38 assert(point.size() == domainSet.getNumDimAndSymbolVars() &&
39 "Point has incorrect dimensionality!");
40
41 Optional<SmallVector<int64_t, 8>> maybeLocalValues =
42 getDomain().containsPointNoLocal(point);
43 if (!maybeLocalValues)
44 return {};
45
46 // The point lies in the domain, so we need to compute the output value.
47 SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
48 // The given point didn't include the values of locals which the output is a
49 // function of; we have computed one possible set of values and use them
50 // here. The function is not allowed to have local vars that take more than
51 // one possible value.
52 pointHomogenous.append(*maybeLocalValues);
53 // The matrix `output` has an affine expression in the ith row, corresponding
54 // to the expression for the ith value in the output vector. The last column
55 // of the matrix contains the constant term. Let v be the input point with
56 // a 1 appended at the end. We can see that output * v gives the desired
57 // output vector.
58 pointHomogenous.emplace_back(1);
59 SmallVector<int64_t, 8> result =
60 output.postMultiplyWithColumn(pointHomogenous);
61 assert(result.size() == getNumOutputs());
62 return result;
63 }
64
65 Optional<SmallVector<int64_t, 8>>
valueAt(ArrayRef<int64_t> point) const66 PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
67 assert(point.size() == getNumInputs() &&
68 "Point has incorrect dimensionality!");
69 for (const MultiAffineFunction &piece : pieces)
70 if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point))
71 return output;
72 return {};
73 }
74
print(raw_ostream & os) const75 void MultiAffineFunction::print(raw_ostream &os) const {
76 os << "Domain:";
77 domainSet.print(os);
78 os << "Output:\n";
79 output.print(os);
80 os << "\n";
81 }
82
dump() const83 void MultiAffineFunction::dump() const { print(llvm::errs()); }
84
isEqual(const MultiAffineFunction & other) const85 bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
86 return getDomainSpace().isCompatible(other.getDomainSpace()) &&
87 getDomain().isEqual(other.getDomain()) &&
88 isEqualWhereDomainsOverlap(other);
89 }
90
insertVar(VarKind kind,unsigned pos,unsigned num)91 unsigned MultiAffineFunction::insertVar(VarKind kind, unsigned pos,
92 unsigned num) {
93 assert(kind != VarKind::Domain && "Domain has to be zero in a set");
94 unsigned absolutePos = domainSet.getVarKindOffset(kind) + pos;
95 output.insertColumns(absolutePos, num);
96 return domainSet.insertVar(kind, pos, num);
97 }
98
removeVarRange(VarKind kind,unsigned varStart,unsigned varLimit)99 void MultiAffineFunction::removeVarRange(VarKind kind, unsigned varStart,
100 unsigned varLimit) {
101 output.removeColumns(varStart + domainSet.getVarKindOffset(kind),
102 varLimit - varStart);
103 domainSet.removeVarRange(kind, varStart, varLimit);
104 }
105
truncateOutput(unsigned count)106 void MultiAffineFunction::truncateOutput(unsigned count) {
107 assert(count <= output.getNumRows());
108 output.resizeVertically(count);
109 }
110
truncateOutput(unsigned count)111 void PWMAFunction::truncateOutput(unsigned count) {
112 assert(count <= numOutputs);
113 for (MultiAffineFunction &piece : pieces)
114 piece.truncateOutput(count);
115 numOutputs = count;
116 }
117
mergeLocalVars(MultiAffineFunction & other)118 void MultiAffineFunction::mergeLocalVars(MultiAffineFunction &other) {
119 // Merge output local vars of both functions without using division
120 // information i.e. append local vars of `other` to `this` and insert
121 // local vars of `this` to `other` at the start of it's local vars.
122 output.insertColumns(domainSet.getVarKindEnd(VarKind::Local),
123 other.domainSet.getNumLocalVars());
124 other.output.insertColumns(other.domainSet.getVarKindOffset(VarKind::Local),
125 domainSet.getNumLocalVars());
126
127 auto merge = [this, &other](unsigned i, unsigned j) -> bool {
128 // Merge local at position j into local at position i in function domain.
129 domainSet.eliminateRedundantLocalVar(i, j);
130 other.domainSet.eliminateRedundantLocalVar(i, j);
131
132 unsigned localOffset = domainSet.getVarKindOffset(VarKind::Local);
133
134 // Merge local at position j into local at position i in output domain.
135 output.addToColumn(localOffset + j, localOffset + i, 1);
136 output.removeColumn(localOffset + j);
137 other.output.addToColumn(localOffset + j, localOffset + i, 1);
138 other.output.removeColumn(localOffset + j);
139
140 return true;
141 };
142
143 presburger::mergeLocalVars(domainSet, other.domainSet, merge);
144 }
145
isEqualWhereDomainsOverlap(MultiAffineFunction other) const146 bool MultiAffineFunction::isEqualWhereDomainsOverlap(
147 MultiAffineFunction other) const {
148 if (!getDomainSpace().isCompatible(other.getDomainSpace()))
149 return false;
150
151 // `commonFunc` has the same output as `this`.
152 MultiAffineFunction commonFunc = *this;
153 // After this merge, `commonFunc` and `other` have the same local vars; they
154 // are merged.
155 commonFunc.mergeLocalVars(other);
156 // After this, the domain of `commonFunc` will be the intersection of the
157 // domains of `this` and `other`.
158 commonFunc.domainSet.append(other.domainSet);
159
160 // `commonDomainMatching` contains the subset of the common domain
161 // where the outputs of `this` and `other` match.
162 //
163 // We want to add constraints equating the outputs of `this` and `other`.
164 // However, `this` may have difference local vars from `other`, whereas we
165 // need both to have the same locals. Accordingly, we use `commonFunc.output`
166 // in place of `this->output`, since `commonFunc` has the same output but also
167 // has its locals merged.
168 IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
169 for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
170 commonDomainMatching.addEquality(
171 subtract(commonFunc.output.getRow(row), other.output.getRow(row)));
172
173 // If the whole common domain is a subset of commonDomainMatching, then they
174 // are equal and the two functions match on the whole common domain.
175 return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
176 }
177
178 /// Two PWMAFunctions are equal if they have the same dimensionalities,
179 /// the same domain, and take the same value at every point in the domain.
isEqual(const PWMAFunction & other) const180 bool PWMAFunction::isEqual(const PWMAFunction &other) const {
181 if (!space.isCompatible(other.space))
182 return false;
183
184 if (!this->getDomain().isEqual(other.getDomain()))
185 return false;
186
187 // Check if, whenever the domains of a piece of `this` and a piece of `other`
188 // overlap, they take the same output value. If `this` and `other` have the
189 // same domain (checked above), then this check passes iff the two functions
190 // have the same output at every point in the domain.
191 for (const MultiAffineFunction &aPiece : this->pieces)
192 for (const MultiAffineFunction &bPiece : other.pieces)
193 if (!aPiece.isEqualWhereDomainsOverlap(bPiece))
194 return false;
195 return true;
196 }
197
addPiece(const MultiAffineFunction & piece)198 void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
199 assert(space.isCompatible(piece.getDomainSpace()) &&
200 "Piece to be added is not compatible with this PWMAFunction!");
201 assert(piece.isConsistent() && "Piece is internally inconsistent!");
202 assert(this->getDomain()
203 .intersect(PresburgerSet(piece.getDomain()))
204 .isIntegerEmpty() &&
205 "New piece's domain overlaps with that of existing pieces!");
206 pieces.push_back(piece);
207 }
208
addPiece(const IntegerPolyhedron & domain,const Matrix & output)209 void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
210 const Matrix &output) {
211 addPiece(MultiAffineFunction(domain, output));
212 }
213
addPiece(const PresburgerSet & domain,const Matrix & output)214 void PWMAFunction::addPiece(const PresburgerSet &domain, const Matrix &output) {
215 for (const IntegerRelation &newDom : domain.getAllDisjuncts())
216 addPiece(IntegerPolyhedron(newDom), output);
217 }
218
print(raw_ostream & os) const219 void PWMAFunction::print(raw_ostream &os) const {
220 os << pieces.size() << " pieces:\n";
221 for (const MultiAffineFunction &piece : pieces)
222 piece.print(os);
223 }
224
dump() const225 void PWMAFunction::dump() const { print(llvm::errs()); }
226
/// Combine `this` with `func` into a single function defined on the union of
/// both domains. Where only one of the two is defined, its output is used.
/// Where both are defined, `tiebreak` decides: it is called as
/// `tiebreak(funcB, funcA)` with a piece `funcB` of `func` and a piece
/// `funcA` of `this`, and returns the set on which the output of `funcB` is
/// to be used.
PWMAFunction PWMAFunction::unionFunction(
    const PWMAFunction &func,
    llvm::function_ref<PresburgerSet(MultiAffineFunction maf1,
                                     MultiAffineFunction maf2)>
        tiebreak) const {
  assert(getNumOutputs() == func.getNumOutputs() &&
         "Number of outputs of functions should be same.");
  assert(getSpace().isCompatible(func.getSpace()) &&
         "Space is not compatible.");

  // Outline of the algorithm:
  //  1. Add the output of funcB on the part of the domain where both funcA
  //     and funcB are defined and `tiebreak` chooses funcB.
  //  2. Add the output of funcA where funcB is not defined or where
  //     `tiebreak` chooses funcA over funcB.
  //  3. Add the output of funcB where funcA is not defined.

  // Steps 1 and 2: handle every piece of `this`, covering both the common
  // and the non-common parts of its domain.
  PWMAFunction result(getSpace(), getNumOutputs());
  for (const MultiAffineFunction &funcA : pieces) {
    // Part of funcA's domain on which funcA's output is still to be used.
    PresburgerSet remaining(funcA.getDomain());
    for (const MultiAffineFunction &funcB : func.pieces) {
      // Set on which funcB is better than funcA. `tiebreak` is expected to
      // guarantee that the disjuncts of this set are disjoint.
      PresburgerSet better = tiebreak(funcB, funcA);
      result.addPiece(better, funcB.getOutputMatrix());
      remaining = remaining.subtract(better);
    }
    // Add the output of funcA where it is better than funcB, or where funcB
    // is not defined.
    //
    // `remaining` is guaranteed to be disjoint from the pieces added so far,
    // because each of those pieces is either:
    //  - a subset of the domain of another MAF in `this`, which is
    //    guaranteed to be disjoint from `remaining`, or
    //  - one of the pieces added for `funcB` above, all of which have been
    //    subtracted from `remaining`.
    result.addPiece(remaining, funcA.getOutputMatrix());
  }

  // Step 3: add the parts of funcB which are not shared with funcA.
  PresburgerSet domainOfThis = getDomain();
  for (const MultiAffineFunction &funcB : func.pieces) {
    result.addPiece(funcB.getDomain().subtract(domainOfThis),
                    funcB.getOutputMatrix());
  }

  return result;
}
277
278 /// A tiebreak function which breaks ties by comparing the outputs
279 /// lexicographically. If `lexMin` is true, then the ties are broken by
280 /// taking the lexicographically smaller output and otherwise, by taking the
281 /// lexicographically larger output.
282 template <bool lexMin>
tiebreakLex(const MultiAffineFunction & mafA,const MultiAffineFunction & mafB)283 static PresburgerSet tiebreakLex(const MultiAffineFunction &mafA,
284 const MultiAffineFunction &mafB) {
285 // TODO: Support local variables here.
286 assert(mafA.getDomainSpace().isCompatible(mafB.getDomainSpace()) &&
287 "Domain spaces should be compatible.");
288 assert(mafA.getNumOutputs() == mafB.getNumOutputs() &&
289 "Number of outputs of both functions should be same.");
290 assert(mafA.getDomain().getNumLocalVars() == 0 &&
291 "Local variables are not supported yet.");
292
293 PresburgerSpace compatibleSpace = mafA.getDomain().getSpaceWithoutLocals();
294 const PresburgerSpace &space = mafA.getDomain().getSpace();
295
296 // We first create the set `result`, corresponding to the set where output
297 // of mafA is lexicographically larger/smaller than mafB. This is done by
298 // creating a PresburgerSet with the following constraints:
299 //
300 // (outA[0] > outB[0]) U
301 // (outA[0] = outB[0], outA[1] > outA[1]) U
302 // (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U
303 // ...
304 // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
305 //
306 // where `n` is the number of outputs.
307 // If `lexMin` is set, the complement inequality is used:
308 //
309 // (outA[0] < outB[0]) U
310 // (outA[0] = outB[0], outA[1] < outA[1]) U
311 // (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U
312 // ...
313 // (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
314 PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace);
315 IntegerPolyhedron levelSet(/*numReservedInequalities=*/1,
316 /*numReservedEqualities=*/mafA.getNumOutputs(),
317 /*numReservedCols=*/space.getNumVars() + 1, space);
318 for (unsigned level = 0; level < mafA.getNumOutputs(); ++level) {
319
320 // Create the expression `outA - outB` for this level.
321 SmallVector<int64_t, 8> subExpr =
322 subtract(mafA.getOutputExpr(level), mafB.getOutputExpr(level));
323
324 if (lexMin) {
325 // For lexMin, we add an upper bound of -1:
326 // outA - outB <= -1
327 // outA <= outB - 1
328 // outA < outB
329 levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, -1);
330 } else {
331 // For lexMax, we add a lower bound of 1:
332 // outA - outB >= 1
333 // outA > outB + 1
334 // outA > outB
335 levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, 1);
336 }
337
338 // Union the set with the result.
339 result.unionInPlace(levelSet);
340 // There is only 1 inequality in `levelSet`, so the index is always 0.
341 levelSet.removeInequality(0);
342 // Add equality `outA - outB == 0` for this level for next iteration.
343 levelSet.addEquality(subExpr);
344 }
345
346 // We then intersect `result` with the domain of mafA and mafB, to only
347 // tiebreak on the domain where both are defined.
348 result = result.intersect(PresburgerSet(mafA.getDomain()))
349 .intersect(PresburgerSet(mafB.getDomain()));
350
351 return result;
352 }
353
/// Union of `this` and `func` in which overlaps are resolved with the
/// lexicographic-minimum tiebreak.
PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
  return unionFunction(func, tiebreakLex</*lexMin=*/true>);
}
357
/// Union of `this` and `func` in which overlaps are resolved with the
/// lexicographic-maximum tiebreak.
PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
  return unionFunction(func, tiebreakLex</*lexMin=*/false>);
}
361