1 //===- Utils.cpp - General utilities for Presburger library ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Utility functions required by the Presburger Library.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Analysis/Presburger/Utils.h"
14 #include "mlir/Analysis/Presburger/IntegerRelation.h"
15 #include "mlir/Support/LogicalResult.h"
16 #include "mlir/Support/MathExtras.h"
17 
18 using namespace mlir;
19 using namespace presburger;
20 
21 /// Normalize a division's `dividend` and the `divisor` by their GCD. For
22 /// example: if the dividend and divisor are [2,0,4] and 4 respectively,
23 /// they get normalized to [1,0,2] and 2.
24 static void normalizeDivisionByGCD(SmallVectorImpl<int64_t> &dividend,
25                                    unsigned &divisor) {
26   if (divisor == 0 || dividend.empty())
27     return;
28   // We take the absolute value of dividend's coefficients to make sure that
29   // `gcd` is positive.
30   int64_t gcd =
31       llvm::greatestCommonDivisor(std::abs(dividend.front()), int64_t(divisor));
32 
33   // The reason for ignoring the constant term is as follows.
34   // For a division:
35   //      floor((a + m.f(x))/(m.d))
36   // It can be replaced by:
37   //      floor((floor(a/m) + f(x))/d)
38   // Since `{a/m}/d` in the dividend satisfies 0 <= {a/m}/d < 1/d, it will not
39   // influence the result of the floor division and thus, can be ignored.
40   for (size_t i = 1, m = dividend.size() - 1; i < m; i++) {
41     gcd = llvm::greatestCommonDivisor(std::abs(dividend[i]), gcd);
42     if (gcd == 1)
43       return;
44   }
45 
46   // Normalize the dividend and the denominator.
47   std::transform(dividend.begin(), dividend.end(), dividend.begin(),
48                  [gcd](int64_t &n) { return floorDiv(n, gcd); });
49   divisor /= gcd;
50 }
51 
52 /// Check if the pos^th identifier can be represented as a division using upper
53 /// bound inequality at position `ubIneq` and lower bound inequality at position
54 /// `lbIneq`.
55 ///
56 /// Let `id` be the pos^th identifier, then `id` is equivalent to
57 /// `expr floordiv divisor` if there are constraints of the form:
58 ///      0 <= expr - divisor * id <= divisor - 1
59 /// Rearranging, we have:
60 ///       divisor * id - expr + (divisor - 1) >= 0  <-- Lower bound for 'id'
61 ///      -divisor * id + expr                 >= 0  <-- Upper bound for 'id'
62 ///
63 /// For example:
64 ///     32*k >= 16*i + j - 31                 <-- Lower bound for 'k'
65 ///     32*k  <= 16*i + j                     <-- Upper bound for 'k'
66 ///     expr = 16*i + j, divisor = 32
67 ///     k = ( 16*i + j ) floordiv 32
68 ///
69 ///     4q >= i + j - 2                       <-- Lower bound for 'q'
70 ///     4q <= i + j + 1                       <-- Upper bound for 'q'
71 ///     expr = i + j + 1, divisor = 4
72 ///     q = (i + j + 1) floordiv 4
73 //
74 /// This function also supports detecting divisions from bounds that are
75 /// strictly tighter than the division bounds described above, since tighter
76 /// bounds imply the division bounds. For example:
77 ///     4q - i - j + 2 >= 0                       <-- Lower bound for 'q'
78 ///    -4q + i + j     >= 0                       <-- Tight upper bound for 'q'
79 ///
80 /// To extract floor divisions with tighter bounds, we assume that that the
81 /// constraints are of the form:
82 ///     c <= expr - divisior * id <= divisor - 1, where 0 <= c <= divisor - 1
83 /// Rearranging, we have:
84 ///     divisor * id - expr + (divisor - 1) >= 0  <-- Lower bound for 'id'
85 ///    -divisor * id + expr - c             >= 0  <-- Upper bound for 'id'
86 ///
87 /// If successful, `expr` is set to dividend of the division and `divisor` is
88 /// set to the denominator of the division. The final division expression is
89 /// normalized by GCD.
90 static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
91                                 unsigned ubIneq, unsigned lbIneq,
92                                 SmallVector<int64_t, 8> &expr,
93                                 unsigned &divisor) {
94 
95   assert(pos <= cst.getNumIds() && "Invalid identifier position");
96   assert(ubIneq <= cst.getNumInequalities() &&
97          "Invalid upper bound inequality position");
98   assert(lbIneq <= cst.getNumInequalities() &&
99          "Invalid upper bound inequality position");
100 
101   // Extract divisor from the lower bound.
102   divisor = cst.atIneq(lbIneq, pos);
103 
104   // First, check if the constraints are opposite of each other except the
105   // constant term.
106   unsigned i = 0, e = 0;
107   for (i = 0, e = cst.getNumIds(); i < e; ++i)
108     if (cst.atIneq(ubIneq, i) != -cst.atIneq(lbIneq, i))
109       break;
110 
111   if (i < e)
112     return failure();
113 
114   // Then, check if the constant term is of the proper form.
115   // Due to the form of the upper/lower bound inequalities, the sum of their
116   // constants is `divisor - 1 - c`. From this, we can extract c:
117   int64_t constantSum = cst.atIneq(lbIneq, cst.getNumCols() - 1) +
118                         cst.atIneq(ubIneq, cst.getNumCols() - 1);
119   int64_t c = divisor - 1 - constantSum;
120 
121   // Check if `c` satisfies the condition `0 <= c <= divisor - 1`. This also
122   // implictly checks that `divisor` is positive.
123   if (!(c >= 0 && c <= divisor - 1))
124     return failure();
125 
126   // The inequality pair can be used to extract the division.
127   // Set `expr` to the dividend of the division except the constant term, which
128   // is set below.
129   expr.resize(cst.getNumCols(), 0);
130   for (i = 0, e = cst.getNumIds(); i < e; ++i)
131     if (i != pos)
132       expr[i] = cst.atIneq(ubIneq, i);
133 
134   // From the upper bound inequality's form, its constant term is equal to the
135   // constant term of `expr`, minus `c`. From this,
136   // constant term of `expr` = constant term of upper bound + `c`.
137   expr.back() = cst.atIneq(ubIneq, cst.getNumCols() - 1) + c;
138   normalizeDivisionByGCD(expr, divisor);
139 
140   return success();
141 }
142 
143 /// Check if the pos^th identifier can be represented as a division using
144 /// equality at position `eqInd`.
145 ///
146 /// For example:
147 ///     32*k == 16*i + j - 31                 <-- `eqInd` for 'k'
148 ///     expr = 16*i + j - 31, divisor = 32
149 ///     k = (16*i + j - 31) floordiv 32
150 ///
151 /// If successful, `expr` is set to dividend of the division and `divisor` is
152 /// set to the denominator of the division. The final division expression is
153 /// normalized by GCD.
154 static LogicalResult getDivRepr(const IntegerRelation &cst, unsigned pos,
155                                 unsigned eqInd, SmallVector<int64_t, 8> &expr,
156                                 unsigned &divisor) {
157 
158   assert(pos <= cst.getNumIds() && "Invalid identifier position");
159   assert(eqInd <= cst.getNumEqualities() && "Invalid equality position");
160 
161   // Extract divisor, the divisor can be negative and hence its sign information
162   // is stored in `signDiv` to reverse the sign of dividend's coefficients.
163   // Equality must involve the pos-th variable and hence `tempDiv` != 0.
164   int64_t tempDiv = cst.atEq(eqInd, pos);
165   if (tempDiv == 0)
166     return failure();
167   int64_t signDiv = tempDiv < 0 ? -1 : 1;
168 
169   // The divisor is always a positive integer.
170   divisor = tempDiv * signDiv;
171 
172   expr.resize(cst.getNumCols(), 0);
173   for (unsigned i = 0, e = cst.getNumIds(); i < e; ++i)
174     if (i != pos)
175       expr[i] = signDiv * cst.atEq(eqInd, i);
176 
177   expr.back() = signDiv * cst.atEq(eqInd, cst.getNumCols() - 1);
178   normalizeDivisionByGCD(expr, divisor);
179 
180   return success();
181 }
182 
183 // Returns `false` if the constraints depends on a variable for which an
184 // explicit representation has not been found yet, otherwise returns `true`.
185 static bool checkExplicitRepresentation(const IntegerRelation &cst,
186                                         ArrayRef<bool> foundRepr,
187                                         ArrayRef<int64_t> dividend,
188                                         unsigned pos) {
189   // Exit to avoid circular dependencies between divisions.
190   for (unsigned c = 0, e = cst.getNumIds(); c < e; ++c) {
191     if (c == pos)
192       continue;
193 
194     if (!foundRepr[c] && dividend[c] != 0) {
195       // Expression can't be constructed as it depends on a yet unknown
196       // identifier.
197       //
198       // TODO: Visit/compute the identifiers in an order so that this doesn't
199       // happen. More complex but much more efficient.
200       return false;
201     }
202   }
203 
204   return true;
205 }
206 
207 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
208 /// function of other identifiers (where the divisor is a positive constant).
209 /// `foundRepr` contains a boolean for each identifier indicating if the
210 /// explicit representation for that identifier has already been computed.
211 /// Returns the `MaybeLocalRepr` struct which contains the indices of the
212 /// constraints that can be expressed as a floordiv of an affine function. If
213 /// the representation could be computed, `dividend` and `denominator` are set.
214 /// If the representation could not be computed, the kind attribute in
215 /// `MaybeLocalRepr` is set to None.
216 MaybeLocalRepr presburger::computeSingleVarRepr(
217     const IntegerRelation &cst, ArrayRef<bool> foundRepr, unsigned pos,
218     SmallVector<int64_t, 8> &dividend, unsigned &divisor) {
219   assert(pos < cst.getNumIds() && "invalid position");
220   assert(foundRepr.size() == cst.getNumIds() &&
221          "Size of foundRepr does not match total number of variables");
222 
223   SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
224   cst.getLowerAndUpperBoundIndices(pos, &lbIndices, &ubIndices, &eqIndices);
225   MaybeLocalRepr repr{};
226 
227   for (unsigned ubPos : ubIndices) {
228     for (unsigned lbPos : lbIndices) {
229       // Attempt to get divison representation from ubPos, lbPos.
230       if (failed(getDivRepr(cst, pos, ubPos, lbPos, dividend, divisor)))
231         continue;
232 
233       if (!checkExplicitRepresentation(cst, foundRepr, dividend, pos))
234         continue;
235 
236       repr.kind = ReprKind::Inequality;
237       repr.repr.inequalityPair = {ubPos, lbPos};
238       return repr;
239     }
240   }
241   for (unsigned eqPos : eqIndices) {
242     // Attempt to get divison representation from eqPos.
243     if (failed(getDivRepr(cst, pos, eqPos, dividend, divisor)))
244       continue;
245 
246     if (!checkExplicitRepresentation(cst, foundRepr, dividend, pos))
247       continue;
248 
249     repr.kind = ReprKind::Equality;
250     repr.repr.equalityIdx = eqPos;
251     return repr;
252   }
253   return repr;
254 }
255 
256 void presburger::removeDuplicateDivs(
257     std::vector<SmallVector<int64_t, 8>> &divs,
258     SmallVectorImpl<unsigned> &denoms, unsigned localOffset,
259     llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
260 
261   // Find and merge duplicate divisions.
262   // TODO: Add division normalization to support divisions that differ by
263   // a constant.
264   // TODO: Add division ordering such that a division representation for local
265   // identifier at position `i` only depends on local identifiers at position <
266   // `i`. This would make sure that all divisions depending on other local
267   // variables that can be merged, are merged.
268   for (unsigned i = 0; i < divs.size(); ++i) {
269     // Check if a division representation exists for the `i^th` local id.
270     if (denoms[i] == 0)
271       continue;
272     // Check if a division exists which is a duplicate of the division at `i`.
273     for (unsigned j = i + 1; j < divs.size(); ++j) {
274       // Check if a division representation exists for the `j^th` local id.
275       if (denoms[j] == 0)
276         continue;
277       // Check if the denominators match.
278       if (denoms[i] != denoms[j])
279         continue;
280       // Check if the representations are equal.
281       if (divs[i] != divs[j])
282         continue;
283 
284       // Merge divisions at position `j` into division at position `i`. If
285       // merge fails, do not merge these divs.
286       bool mergeResult = merge(i, j);
287       if (!mergeResult)
288         continue;
289 
290       // Update division information to reflect merging.
291       for (unsigned k = 0, g = divs.size(); k < g; ++k) {
292         SmallVector<int64_t, 8> &div = divs[k];
293         if (denoms[k] != 0) {
294           div[localOffset + i] += div[localOffset + j];
295           div.erase(div.begin() + localOffset + j);
296         }
297       }
298 
299       divs.erase(divs.begin() + j);
300       denoms.erase(denoms.begin() + j);
301       // Since `j` can never be zero, we do not need to worry about overflows.
302       --j;
303     }
304   }
305 }
306 
307 void presburger::mergeLocalIds(
308     IntegerRelation &relA, IntegerRelation &relB,
309     llvm::function_ref<bool(unsigned i, unsigned j)> merge) {
310   assert(relA.getSpace().isCompatible(relB.getSpace()) &&
311          "Spaces should be compatible.");
312 
313   // Merge local ids of relA and relB without using division information,
314   // i.e. append local ids of `relB` to `relA` and insert local ids of `relA`
315   // to `relB` at start of its local ids.
316   unsigned initLocals = relA.getNumLocalIds();
317   relA.insertId(IdKind::Local, relA.getNumLocalIds(), relB.getNumLocalIds());
318   relB.insertId(IdKind::Local, 0, initLocals);
319 
320   // Get division representations from each rel.
321   std::vector<SmallVector<int64_t, 8>> divsA, divsB;
322   SmallVector<unsigned, 4> denomsA, denomsB;
323   relA.getLocalReprs(divsA, denomsA);
324   relB.getLocalReprs(divsB, denomsB);
325 
326   // Copy division information for relB into `divsA` and `denomsA`, so that
327   // these have the combined division information of both rels. Since newly
328   // added local variables in relA and relB have no constraints, they will not
329   // have any division representation.
330   std::copy(divsB.begin() + initLocals, divsB.end(),
331             divsA.begin() + initLocals);
332   std::copy(denomsB.begin() + initLocals, denomsB.end(),
333             denomsA.begin() + initLocals);
334 
335   // Merge all divisions by removing duplicate divisions.
336   unsigned localOffset = relA.getIdKindOffset(IdKind::Local);
337   presburger::removeDuplicateDivs(divsA, denomsA, localOffset, merge);
338 }
339 
340 int64_t presburger::gcdRange(ArrayRef<int64_t> range) {
341   int64_t gcd = 0;
342   for (int64_t elem : range) {
343     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(elem));
344     if (gcd == 1)
345       return gcd;
346   }
347   return gcd;
348 }
349 
350 int64_t presburger::normalizeRange(MutableArrayRef<int64_t> range) {
351   int64_t gcd = gcdRange(range);
352   if (gcd == 0 || gcd == 1)
353     return gcd;
354   for (int64_t &elem : range)
355     elem /= gcd;
356   return gcd;
357 }
358 
359 void presburger::normalizeDiv(MutableArrayRef<int64_t> num, int64_t &denom) {
360   assert(denom > 0 && "denom must be positive!");
361   int64_t gcd = llvm::greatestCommonDivisor(gcdRange(num), denom);
362   for (int64_t &coeff : num)
363     coeff /= gcd;
364   denom /= gcd;
365 }
366 
367 SmallVector<int64_t, 8> presburger::getNegatedCoeffs(ArrayRef<int64_t> coeffs) {
368   SmallVector<int64_t, 8> negatedCoeffs;
369   negatedCoeffs.reserve(coeffs.size());
370   for (int64_t coeff : coeffs)
371     negatedCoeffs.emplace_back(-coeff);
372   return negatedCoeffs;
373 }
374 
375 SmallVector<int64_t, 8> presburger::getComplementIneq(ArrayRef<int64_t> ineq) {
376   SmallVector<int64_t, 8> coeffs;
377   coeffs.reserve(ineq.size());
378   for (int64_t coeff : ineq)
379     coeffs.emplace_back(-coeff);
380   --coeffs.back();
381   return coeffs;
382 }
383