1 //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Structures for affine/polyhedral analysis of affine dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
14 #include "mlir/Analysis/Presburger/LinearTransform.h"
15 #include "mlir/Analysis/Presburger/Simplex.h"
16 #include "mlir/Analysis/Presburger/Utils.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
20 #include "mlir/IR/AffineExprVisitor.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/Support/LLVM.h"
23 #include "mlir/Support/MathExtras.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 #define DEBUG_TYPE "affine-structures"
31 
32 using namespace mlir;
33 using namespace presburger;
34 
35 namespace {
36 
37 // See comments for SimpleAffineExprFlattener.
38 // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
39 // constraint information associated with mod's, floordiv's, and ceildiv's
40 // in FlatAffineValueConstraints 'localVarCst'.
41 struct AffineExprFlattener : public SimpleAffineExprFlattener {
42 public:
43   // Constraints connecting newly introduced local variables (for mod's and
44   // div's) to existing (dimensional and symbolic) ones. These are always
45   // inequalities.
46   IntegerPolyhedron localVarCst;
47 
AffineExprFlattener__anon77af06760111::AffineExprFlattener48   AffineExprFlattener(unsigned nDims, unsigned nSymbols)
49       : SimpleAffineExprFlattener(nDims, nSymbols),
50         localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)) {}
51 
52 private:
53   // Add a local variable (needed to flatten a mod, floordiv, ceildiv expr).
54   // The local variable added is always a floordiv of a pure add/mul affine
55   // function of other variables, coefficients of which are specified in
56   // `dividend' and with respect to the positive constant `divisor'. localExpr
57   // is the simplified tree expression (AffineExpr) corresponding to the
58   // quantifier.
addLocalFloorDivId__anon77af06760111::AffineExprFlattener59   void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
60                           AffineExpr localExpr) override {
61     SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
62     // Update localVarCst.
63     localVarCst.addLocalFloorDiv(dividend, divisor);
64   }
65 };
66 
67 } // namespace
68 
69 // Flattens the expressions in map. Returns failure if 'expr' was unable to be
70 // flattened (i.e., semi-affine expressions not handled yet).
71 static LogicalResult
getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs,unsigned numDims,unsigned numSymbols,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineValueConstraints * localVarCst)72 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
73                         unsigned numSymbols,
74                         std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
75                         FlatAffineValueConstraints *localVarCst) {
76   if (exprs.empty()) {
77     localVarCst->reset(numDims, numSymbols);
78     return success();
79   }
80 
81   AffineExprFlattener flattener(numDims, numSymbols);
82   // Use the same flattener to simplify each expression successively. This way
83   // local variables / expressions are shared.
84   for (auto expr : exprs) {
85     if (!expr.isPureAffine())
86       return failure();
87 
88     flattener.walkPostOrder(expr);
89   }
90 
91   assert(flattener.operandExprStack.size() == exprs.size());
92   flattenedExprs->clear();
93   flattenedExprs->assign(flattener.operandExprStack.begin(),
94                          flattener.operandExprStack.end());
95 
96   if (localVarCst)
97     localVarCst->clearAndCopyFrom(flattener.localVarCst);
98 
99   return success();
100 }
101 
102 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
103 // be flattened (semi-affine expressions not handled yet).
104 LogicalResult
getFlattenedAffineExpr(AffineExpr expr,unsigned numDims,unsigned numSymbols,SmallVectorImpl<int64_t> * flattenedExpr,FlatAffineValueConstraints * localVarCst)105 mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
106                              unsigned numSymbols,
107                              SmallVectorImpl<int64_t> *flattenedExpr,
108                              FlatAffineValueConstraints *localVarCst) {
109   std::vector<SmallVector<int64_t, 8>> flattenedExprs;
110   LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
111                                                 &flattenedExprs, localVarCst);
112   *flattenedExpr = flattenedExprs[0];
113   return ret;
114 }
115 
116 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be
117 /// flattened (i.e., semi-affine expressions not handled yet).
getFlattenedAffineExprs(AffineMap map,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineValueConstraints * localVarCst)118 LogicalResult mlir::getFlattenedAffineExprs(
119     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
120     FlatAffineValueConstraints *localVarCst) {
121   if (map.getNumResults() == 0) {
122     localVarCst->reset(map.getNumDims(), map.getNumSymbols());
123     return success();
124   }
125   return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
126                                    map.getNumSymbols(), flattenedExprs,
127                                    localVarCst);
128 }
129 
getFlattenedAffineExprs(IntegerSet set,std::vector<SmallVector<int64_t,8>> * flattenedExprs,FlatAffineValueConstraints * localVarCst)130 LogicalResult mlir::getFlattenedAffineExprs(
131     IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
132     FlatAffineValueConstraints *localVarCst) {
133   if (set.getNumConstraints() == 0) {
134     localVarCst->reset(set.getNumDims(), set.getNumSymbols());
135     return success();
136   }
137   return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
138                                    set.getNumSymbols(), flattenedExprs,
139                                    localVarCst);
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // FlatAffineConstraints / FlatAffineValueConstraints.
144 //===----------------------------------------------------------------------===//
145 
146 std::unique_ptr<FlatAffineValueConstraints>
clone() const147 FlatAffineValueConstraints::clone() const {
148   return std::make_unique<FlatAffineValueConstraints>(*this);
149 }
150 
151 // Construct from an IntegerSet.
FlatAffineValueConstraints(IntegerSet set)152 FlatAffineValueConstraints::FlatAffineValueConstraints(IntegerSet set)
153     : IntegerPolyhedron(set.getNumInequalities(), set.getNumEqualities(),
154                         set.getNumDims() + set.getNumSymbols() + 1,
155                         PresburgerSpace::getSetSpace(set.getNumDims(),
156                                                      set.getNumSymbols(),
157                                                      /*numLocals=*/0)) {
158 
159   // Resize values.
160   values.resize(getNumDimAndSymbolVars(), None);
161 
162   // Flatten expressions and add them to the constraint system.
163   std::vector<SmallVector<int64_t, 8>> flatExprs;
164   FlatAffineValueConstraints localVarCst;
165   if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
166     assert(false && "flattening unimplemented for semi-affine integer sets");
167     return;
168   }
169   assert(flatExprs.size() == set.getNumConstraints());
170   insertVar(VarKind::Local, getNumVarKind(VarKind::Local),
171             /*num=*/localVarCst.getNumLocalVars());
172 
173   for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
174     const auto &flatExpr = flatExprs[i];
175     assert(flatExpr.size() == getNumCols());
176     if (set.getEqFlags()[i]) {
177       addEquality(flatExpr);
178     } else {
179       addInequality(flatExpr);
180     }
181   }
182   // Add the other constraints involving local vars from flattening.
183   append(localVarCst);
184 }
185 
186 // Construct a hyperrectangular constraint set from ValueRanges that represent
187 // induction variables, lower and upper bounds. `ivs`, `lbs` and `ubs` are
188 // expected to match one to one. The order of variables and constraints is:
189 //
190 // ivs | lbs | ubs | eq/ineq
191 // ----+-----+-----+---------
192 //   1   -1     0      >= 0
193 // ----+-----+-----+---------
194 //  -1    0     1      >= 0
195 //
196 // All dimensions as set as VarKind::SetDim.
197 FlatAffineValueConstraints
getHyperrectangular(ValueRange ivs,ValueRange lbs,ValueRange ubs)198 FlatAffineValueConstraints::getHyperrectangular(ValueRange ivs, ValueRange lbs,
199                                                 ValueRange ubs) {
200   FlatAffineValueConstraints res;
201   unsigned nIvs = ivs.size();
202   assert(nIvs == lbs.size() && "expected as many lower bounds as ivs");
203   assert(nIvs == ubs.size() && "expected as many upper bounds as ivs");
204 
205   if (nIvs == 0)
206     return res;
207 
208   res.appendDimVar(ivs);
209   unsigned lbsStart = res.appendDimVar(lbs);
210   unsigned ubsStart = res.appendDimVar(ubs);
211 
212   MLIRContext *ctx = ivs.front().getContext();
213   for (int ivIdx = 0, e = nIvs; ivIdx < e; ++ivIdx) {
214     // iv - lb >= 0
215     AffineMap lb = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
216                                   getAffineDimExpr(lbsStart + ivIdx, ctx));
217     if (failed(res.addBound(BoundType::LB, ivIdx, lb)))
218       llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
219     // -iv + ub >= 0
220     AffineMap ub = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
221                                   getAffineDimExpr(ubsStart + ivIdx, ctx));
222     if (failed(res.addBound(BoundType::UB, ivIdx, ub)))
223       llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
224   }
225   return res;
226 }
227 
reset(unsigned numReservedInequalities,unsigned numReservedEqualities,unsigned newNumReservedCols,unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals)228 void FlatAffineValueConstraints::reset(unsigned numReservedInequalities,
229                                        unsigned numReservedEqualities,
230                                        unsigned newNumReservedCols,
231                                        unsigned newNumDims,
232                                        unsigned newNumSymbols,
233                                        unsigned newNumLocals) {
234   assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
235          "minimum 1 column");
236   *this = FlatAffineValueConstraints(numReservedInequalities,
237                                      numReservedEqualities, newNumReservedCols,
238                                      newNumDims, newNumSymbols, newNumLocals);
239 }
240 
reset(unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals)241 void FlatAffineValueConstraints::reset(unsigned newNumDims,
242                                        unsigned newNumSymbols,
243                                        unsigned newNumLocals) {
244   reset(/*numReservedInequalities=*/0, /*numReservedEqualities=*/0,
245         /*numReservedCols=*/newNumDims + newNumSymbols + newNumLocals + 1,
246         newNumDims, newNumSymbols, newNumLocals);
247 }
248 
reset(unsigned numReservedInequalities,unsigned numReservedEqualities,unsigned newNumReservedCols,unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals,ArrayRef<Value> valArgs)249 void FlatAffineValueConstraints::reset(
250     unsigned numReservedInequalities, unsigned numReservedEqualities,
251     unsigned newNumReservedCols, unsigned newNumDims, unsigned newNumSymbols,
252     unsigned newNumLocals, ArrayRef<Value> valArgs) {
253   assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
254          "minimum 1 column");
255   SmallVector<Optional<Value>, 8> newVals;
256   if (!valArgs.empty())
257     newVals.assign(valArgs.begin(), valArgs.end());
258 
259   *this = FlatAffineValueConstraints(
260       numReservedInequalities, numReservedEqualities, newNumReservedCols,
261       newNumDims, newNumSymbols, newNumLocals, newVals);
262 }
263 
reset(unsigned newNumDims,unsigned newNumSymbols,unsigned newNumLocals,ArrayRef<Value> valArgs)264 void FlatAffineValueConstraints::reset(unsigned newNumDims,
265                                        unsigned newNumSymbols,
266                                        unsigned newNumLocals,
267                                        ArrayRef<Value> valArgs) {
268   reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
269         newNumSymbols, newNumLocals, valArgs);
270 }
271 
appendDimVar(ValueRange vals)272 unsigned FlatAffineValueConstraints::appendDimVar(ValueRange vals) {
273   unsigned pos = getNumDimVars();
274   insertVar(VarKind::SetDim, pos, vals);
275   return pos;
276 }
277 
appendSymbolVar(ValueRange vals)278 unsigned FlatAffineValueConstraints::appendSymbolVar(ValueRange vals) {
279   unsigned pos = getNumSymbolVars();
280   insertVar(VarKind::Symbol, pos, vals);
281   return pos;
282 }
283 
insertDimVar(unsigned pos,ValueRange vals)284 unsigned FlatAffineValueConstraints::insertDimVar(unsigned pos,
285                                                   ValueRange vals) {
286   return insertVar(VarKind::SetDim, pos, vals);
287 }
288 
insertSymbolVar(unsigned pos,ValueRange vals)289 unsigned FlatAffineValueConstraints::insertSymbolVar(unsigned pos,
290                                                      ValueRange vals) {
291   return insertVar(VarKind::Symbol, pos, vals);
292 }
293 
insertVar(VarKind kind,unsigned pos,unsigned num)294 unsigned FlatAffineValueConstraints::insertVar(VarKind kind, unsigned pos,
295                                                unsigned num) {
296   unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
297 
298   if (kind != VarKind::Local) {
299     values.insert(values.begin() + absolutePos, num, None);
300     assert(values.size() == getNumDimAndSymbolVars());
301   }
302 
303   return absolutePos;
304 }
305 
insertVar(VarKind kind,unsigned pos,ValueRange vals)306 unsigned FlatAffineValueConstraints::insertVar(VarKind kind, unsigned pos,
307                                                ValueRange vals) {
308   assert(!vals.empty() && "expected ValueRange with Values.");
309   assert(kind != VarKind::Local &&
310          "values cannot be attached to local variables.");
311   unsigned num = vals.size();
312   unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
313 
314   // If a Value is provided, insert it; otherwise use None.
315   for (unsigned i = 0; i < num; ++i)
316     values.insert(values.begin() + absolutePos + i,
317                   vals[i] ? Optional<Value>(vals[i]) : None);
318 
319   assert(values.size() == getNumDimAndSymbolVars());
320   return absolutePos;
321 }
322 
hasValues() const323 bool FlatAffineValueConstraints::hasValues() const {
324   return llvm::any_of(
325       values, [](const Optional<Value> &var) { return var.has_value(); });
326 }
327 
328 /// Checks if two constraint systems are in the same space, i.e., if they are
329 /// associated with the same set of variables, appearing in the same order.
areVarsAligned(const FlatAffineValueConstraints & a,const FlatAffineValueConstraints & b)330 static bool areVarsAligned(const FlatAffineValueConstraints &a,
331                            const FlatAffineValueConstraints &b) {
332   return a.getNumDimVars() == b.getNumDimVars() &&
333          a.getNumSymbolVars() == b.getNumSymbolVars() &&
334          a.getNumVars() == b.getNumVars() &&
335          a.getMaybeValues().equals(b.getMaybeValues());
336 }
337 
338 /// Calls areVarsAligned to check if two constraint systems have the same set
339 /// of variables in the same order.
areVarsAlignedWithOther(const FlatAffineValueConstraints & other)340 bool FlatAffineValueConstraints::areVarsAlignedWithOther(
341     const FlatAffineValueConstraints &other) {
342   return areVarsAligned(*this, other);
343 }
344 
345 /// Checks if the SSA values associated with `cst`'s variables in range
346 /// [start, end) are unique.
areVarsUnique(const FlatAffineValueConstraints & cst,unsigned start,unsigned end)347 static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
348     const FlatAffineValueConstraints &cst, unsigned start, unsigned end) {
349 
350   assert(start <= cst.getNumDimAndSymbolVars() &&
351          "Start position out of bounds");
352   assert(end <= cst.getNumDimAndSymbolVars() && "End position out of bounds");
353 
354   if (start >= end)
355     return true;
356 
357   SmallPtrSet<Value, 8> uniqueVars;
358   ArrayRef<Optional<Value>> maybeValues =
359       cst.getMaybeValues().slice(start, end - start);
360   for (Optional<Value> val : maybeValues) {
361     if (val && !uniqueVars.insert(*val).second)
362       return false;
363   }
364   return true;
365 }
366 
367 /// Checks if the SSA values associated with `cst`'s variables are unique.
368 static bool LLVM_ATTRIBUTE_UNUSED
areVarsUnique(const FlatAffineValueConstraints & cst)369 areVarsUnique(const FlatAffineValueConstraints &cst) {
370   return areVarsUnique(cst, 0, cst.getNumDimAndSymbolVars());
371 }
372 
373 /// Checks if the SSA values associated with `cst`'s variables of kind `kind`
374 /// are unique.
375 static bool LLVM_ATTRIBUTE_UNUSED
areVarsUnique(const FlatAffineValueConstraints & cst,VarKind kind)376 areVarsUnique(const FlatAffineValueConstraints &cst, VarKind kind) {
377 
378   if (kind == VarKind::SetDim)
379     return areVarsUnique(cst, 0, cst.getNumDimVars());
380   if (kind == VarKind::Symbol)
381     return areVarsUnique(cst, cst.getNumDimVars(),
382                          cst.getNumDimAndSymbolVars());
383   llvm_unreachable("Unexpected VarKind");
384 }
385 
386 /// Merge and align the variables of A and B starting at 'offset', so that
387 /// both constraint systems get the union of the contained variables that is
388 /// dimension-wise and symbol-wise unique; both constraint systems are updated
389 /// so that they have the union of all variables, with A's original
390 /// variables appearing first followed by any of B's variables that didn't
391 /// appear in A. Local variables in B that have the same division
392 /// representation as local variables in A are merged into one.
393 //  E.g.: Input: A has ((%i, %j) [%M, %N]) and B has (%k, %j) [%P, %N, %M])
394 //        Output: both A, B have (%i, %j, %k) [%M, %N, %P]
mergeAndAlignVars(unsigned offset,FlatAffineValueConstraints * a,FlatAffineValueConstraints * b)395 static void mergeAndAlignVars(unsigned offset, FlatAffineValueConstraints *a,
396                               FlatAffineValueConstraints *b) {
397   assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars());
398   // A merge/align isn't meaningful if a cst's vars aren't distinct.
399   assert(areVarsUnique(*a) && "A's values aren't unique");
400   assert(areVarsUnique(*b) && "B's values aren't unique");
401 
402   assert(std::all_of(a->getMaybeValues().begin() + offset,
403                      a->getMaybeValues().end(),
404                      [](Optional<Value> var) { return var.has_value(); }));
405 
406   assert(std::all_of(b->getMaybeValues().begin() + offset,
407                      b->getMaybeValues().end(),
408                      [](Optional<Value> var) { return var.has_value(); }));
409 
410   SmallVector<Value, 4> aDimValues;
411   a->getValues(offset, a->getNumDimVars(), &aDimValues);
412 
413   {
414     // Merge dims from A into B.
415     unsigned d = offset;
416     for (auto aDimValue : aDimValues) {
417       unsigned loc;
418       if (b->findVar(aDimValue, &loc)) {
419         assert(loc >= offset && "A's dim appears in B's aligned range");
420         assert(loc < b->getNumDimVars() &&
421                "A's dim appears in B's non-dim position");
422         b->swapVar(d, loc);
423       } else {
424         b->insertDimVar(d, aDimValue);
425       }
426       d++;
427     }
428     // Dimensions that are in B, but not in A, are added at the end.
429     for (unsigned t = a->getNumDimVars(), e = b->getNumDimVars(); t < e; t++) {
430       a->appendDimVar(b->getValue(t));
431     }
432     assert(a->getNumDimVars() == b->getNumDimVars() &&
433            "expected same number of dims");
434   }
435 
436   // Merge and align symbols of A and B
437   a->mergeSymbolVars(*b);
438   // Merge and align locals of A and B
439   a->mergeLocalVars(*b);
440 
441   assert(areVarsAligned(*a, *b) && "IDs expected to be aligned");
442 }
443 
444 // Call 'mergeAndAlignVars' to align constraint systems of 'this' and 'other'.
mergeAndAlignVarsWithOther(unsigned offset,FlatAffineValueConstraints * other)445 void FlatAffineValueConstraints::mergeAndAlignVarsWithOther(
446     unsigned offset, FlatAffineValueConstraints *other) {
447   mergeAndAlignVars(offset, this, other);
448 }
449 
450 LogicalResult
composeMap(const AffineValueMap * vMap)451 FlatAffineValueConstraints::composeMap(const AffineValueMap *vMap) {
452   return composeMatchingMap(
453       computeAlignedMap(vMap->getAffineMap(), vMap->getOperands()));
454 }
455 
456 // Similar to `composeMap` except that no Values need be associated with the
457 // constraint system nor are they looked at -- the dimensions and symbols of
458 // `other` are expected to correspond 1:1 to `this` system.
composeMatchingMap(AffineMap other)459 LogicalResult FlatAffineValueConstraints::composeMatchingMap(AffineMap other) {
460   assert(other.getNumDims() == getNumDimVars() && "dim mismatch");
461   assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
462 
463   std::vector<SmallVector<int64_t, 8>> flatExprs;
464   if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
465     return failure();
466   assert(flatExprs.size() == other.getNumResults());
467 
468   // Add dimensions corresponding to the map's results.
469   insertDimVar(/*pos=*/0, /*num=*/other.getNumResults());
470 
471   // We add one equality for each result connecting the result dim of the map to
472   // the other variables.
473   // E.g.: if the expression is 16*i0 + i1, and this is the r^th
474   // iteration/result of the value map, we are adding the equality:
475   // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
476   // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
477   for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
478     const auto &flatExpr = flatExprs[r];
479     assert(flatExpr.size() >= other.getNumInputs() + 1);
480 
481     SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
482     // Set the coefficient for this result to one.
483     eqToAdd[r] = 1;
484 
485     // Dims and symbols.
486     for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
487       // Negate `eq[r]` since the newly added dimension will be set to this one.
488       eqToAdd[e + i] = -flatExpr[i];
489     }
490     // Local columns of `eq` are at the beginning.
491     unsigned j = getNumDimVars() + getNumSymbolVars();
492     unsigned end = flatExpr.size() - 1;
493     for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
494       eqToAdd[j] = -flatExpr[i];
495     }
496 
497     // Constant term.
498     eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
499 
500     // Add the equality connecting the result of the map to this constraint set.
501     addEquality(eqToAdd);
502   }
503 
504   return success();
505 }
506 
507 // Turn a symbol into a dimension.
turnSymbolIntoDim(FlatAffineValueConstraints * cst,Value value)508 static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value value) {
509   unsigned pos;
510   if (cst->findVar(value, &pos) && pos >= cst->getNumDimVars() &&
511       pos < cst->getNumDimAndSymbolVars()) {
512     cst->swapVar(pos, cst->getNumDimVars());
513     cst->setDimSymbolSeparation(cst->getNumSymbolVars() - 1);
514   }
515 }
516 
517 /// Merge and align symbols of `this` and `other` such that both get union of
518 /// of symbols that are unique. Symbols in `this` and `other` should be
519 /// unique. Symbols with Value as `None` are considered to be inequal to all
520 /// other symbols.
mergeSymbolVars(FlatAffineValueConstraints & other)521 void FlatAffineValueConstraints::mergeSymbolVars(
522     FlatAffineValueConstraints &other) {
523 
524   assert(areVarsUnique(*this, VarKind::Symbol) && "Symbol vars are not unique");
525   assert(areVarsUnique(other, VarKind::Symbol) && "Symbol vars are not unique");
526 
527   SmallVector<Value, 4> aSymValues;
528   getValues(getNumDimVars(), getNumDimAndSymbolVars(), &aSymValues);
529 
530   // Merge symbols: merge symbols into `other` first from `this`.
531   unsigned s = other.getNumDimVars();
532   for (Value aSymValue : aSymValues) {
533     unsigned loc;
534     // If the var is a symbol in `other`, then align it, otherwise assume that
535     // it is a new symbol
536     if (other.findVar(aSymValue, &loc) && loc >= other.getNumDimVars() &&
537         loc < other.getNumDimAndSymbolVars())
538       other.swapVar(s, loc);
539     else
540       other.insertSymbolVar(s - other.getNumDimVars(), aSymValue);
541     s++;
542   }
543 
544   // Symbols that are in other, but not in this, are added at the end.
545   for (unsigned t = other.getNumDimVars() + getNumSymbolVars(),
546                 e = other.getNumDimAndSymbolVars();
547        t < e; t++)
548     insertSymbolVar(getNumSymbolVars(), other.getValue(t));
549 
550   assert(getNumSymbolVars() == other.getNumSymbolVars() &&
551          "expected same number of symbols");
552   assert(areVarsUnique(*this, VarKind::Symbol) && "Symbol vars are not unique");
553   assert(areVarsUnique(other, VarKind::Symbol) && "Symbol vars are not unique");
554 }
555 
556 // Changes all symbol variables which are loop IVs to dim variables.
convertLoopIVSymbolsToDims()557 void FlatAffineValueConstraints::convertLoopIVSymbolsToDims() {
558   // Gather all symbols which are loop IVs.
559   SmallVector<Value, 4> loopIVs;
560   for (unsigned i = getNumDimVars(), e = getNumDimAndSymbolVars(); i < e; i++) {
561     if (hasValue(i) && getForInductionVarOwner(getValue(i)))
562       loopIVs.push_back(getValue(i));
563   }
564   // Turn each symbol in 'loopIVs' into a dim variable.
565   for (auto iv : loopIVs) {
566     turnSymbolIntoDim(this, iv);
567   }
568 }
569 
addInductionVarOrTerminalSymbol(Value val)570 void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) {
571   if (containsVar(val))
572     return;
573 
574   // Caller is expected to fully compose map/operands if necessary.
575   assert((isTopLevelValue(val) || isForInductionVar(val)) &&
576          "non-terminal symbol / loop IV expected");
577   // Outer loop IVs could be used in forOp's bounds.
578   if (auto loop = getForInductionVarOwner(val)) {
579     appendDimVar(val);
580     if (failed(this->addAffineForOpDomain(loop)))
581       LLVM_DEBUG(
582           loop.emitWarning("failed to add domain info to constraint system"));
583     return;
584   }
585   // Add top level symbol.
586   appendSymbolVar(val);
587   // Check if the symbol is a constant.
588   if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>())
589     addBound(BoundType::EQ, val, constOp.value());
590 }
591 
592 LogicalResult
addAffineForOpDomain(AffineForOp forOp)593 FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) {
594   unsigned pos;
595   // Pre-condition for this method.
596   if (!findVar(forOp.getInductionVar(), &pos)) {
597     assert(false && "Value not found");
598     return failure();
599   }
600 
601   int64_t step = forOp.getStep();
602   if (step != 1) {
603     if (!forOp.hasConstantLowerBound())
604       LLVM_DEBUG(forOp.emitWarning("domain conservatively approximated"));
605     else {
606       // Add constraints for the stride.
607       // (iv - lb) % step = 0 can be written as:
608       // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
609       // Add local variable 'q' and add the above equality.
610       // The first constraint is q = (iv - lb) floordiv step
611       SmallVector<int64_t, 8> dividend(getNumCols(), 0);
612       int64_t lb = forOp.getConstantLowerBound();
613       dividend[pos] = 1;
614       dividend.back() -= lb;
615       addLocalFloorDiv(dividend, step);
616       // Second constraint: (iv - lb) - step * q = 0.
617       SmallVector<int64_t, 8> eq(getNumCols(), 0);
618       eq[pos] = 1;
619       eq.back() -= lb;
620       // For the local var just added above.
621       eq[getNumCols() - 2] = -step;
622       addEquality(eq);
623     }
624   }
625 
626   if (forOp.hasConstantLowerBound()) {
627     addBound(BoundType::LB, pos, forOp.getConstantLowerBound());
628   } else {
629     // Non-constant lower bound case.
630     if (failed(addBound(BoundType::LB, pos, forOp.getLowerBoundMap(),
631                         forOp.getLowerBoundOperands())))
632       return failure();
633   }
634 
635   if (forOp.hasConstantUpperBound()) {
636     addBound(BoundType::UB, pos, forOp.getConstantUpperBound() - 1);
637     return success();
638   }
639   // Non-constant upper bound case.
640   return addBound(BoundType::UB, pos, forOp.getUpperBoundMap(),
641                   forOp.getUpperBoundOperands());
642 }
643 
644 LogicalResult
addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,ArrayRef<AffineMap> ubMaps,ArrayRef<Value> operands)645 FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
646                                                    ArrayRef<AffineMap> ubMaps,
647                                                    ArrayRef<Value> operands) {
648   assert(lbMaps.size() == ubMaps.size());
649   assert(lbMaps.size() <= getNumDimVars());
650 
651   for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
652     AffineMap lbMap = lbMaps[i];
653     AffineMap ubMap = ubMaps[i];
654     assert(!lbMap || lbMap.getNumInputs() == operands.size());
655     assert(!ubMap || ubMap.getNumInputs() == operands.size());
656 
657     // Check if this slice is just an equality along this dimension. If so,
658     // retrieve the existing loop it equates to and add it to the system.
659     if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
660         ubMap.getNumResults() == 1 &&
661         lbMap.getResult(0) + 1 == ubMap.getResult(0) &&
662         // The condition above will be true for maps describing a single
663         // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
664         // Make sure we skip those cases by checking that the lb result is not
665         // just a constant.
666         !lbMap.getResult(0).isa<AffineConstantExpr>()) {
667       // Limited support: we expect the lb result to be just a loop dimension.
668       // Not supported otherwise for now.
669       AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
670       if (!result)
671         return failure();
672 
673       AffineForOp loop =
674           getForInductionVarOwner(operands[result.getPosition()]);
675       if (!loop)
676         return failure();
677 
678       if (failed(addAffineForOpDomain(loop)))
679         return failure();
680       continue;
681     }
682 
683     // This slice refers to a loop that doesn't exist in the IR yet. Add its
684     // bounds to the system assuming its dimension variable position is the
685     // same as the position of the loop in the loop nest.
686     if (lbMap && failed(addBound(BoundType::LB, i, lbMap, operands)))
687       return failure();
688     if (ubMap && failed(addBound(BoundType::UB, i, ubMap, operands)))
689       return failure();
690   }
691   return success();
692 }
693 
addAffineIfOpDomain(AffineIfOp ifOp)694 void FlatAffineValueConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
695   // Create the base constraints from the integer set attached to ifOp.
696   FlatAffineValueConstraints cst(ifOp.getIntegerSet());
697 
698   // Bind vars in the constraints to ifOp operands.
699   SmallVector<Value, 4> operands = ifOp.getOperands();
700   cst.setValues(0, cst.getNumDimAndSymbolVars(), operands);
701 
702   // Merge the constraints from ifOp to the current domain. We need first merge
703   // and align the IDs from both constraints, and then append the constraints
704   // from the ifOp into the current one.
705   mergeAndAlignVarsWithOther(0, &cst);
706   append(cst);
707 }
708 
hasConsistentState() const709 bool FlatAffineValueConstraints::hasConsistentState() const {
710   return IntegerPolyhedron::hasConsistentState() &&
711          values.size() == getNumDimAndSymbolVars();
712 }
713 
removeVarRange(VarKind kind,unsigned varStart,unsigned varLimit)714 void FlatAffineValueConstraints::removeVarRange(VarKind kind, unsigned varStart,
715                                                 unsigned varLimit) {
716   IntegerPolyhedron::removeVarRange(kind, varStart, varLimit);
717   unsigned offset = getVarKindOffset(kind);
718 
719   if (kind != VarKind::Local) {
720     values.erase(values.begin() + varStart + offset,
721                  values.begin() + varLimit + offset);
722   }
723 }
724 
725 // Determine whether the variable at 'pos' (say var_r) can be expressed as
726 // modulo of another known variable (say var_n) w.r.t a constant. For example,
727 // if the following constraints hold true:
728 // ```
729 // 0 <= var_r <= divisor - 1
730 // var_n - (divisor * q_expr) = var_r
731 // ```
732 // where `var_n` is a known variable (called dividend), and `q_expr` is an
733 // `AffineExpr` (called the quotient expression), `var_r` can be written as:
734 //
735 // `var_r = var_n mod divisor`.
736 //
737 // Additionally, in a special case of the above constaints where `q_expr` is an
738 // variable itself that is not yet known (say `var_q`), it can be written as a
739 // floordiv in the following way:
740 //
741 // `var_q = var_n floordiv divisor`.
742 //
743 // Returns true if the above mod or floordiv are detected, updating 'memo' with
744 // these new expressions. Returns false otherwise.
detectAsMod(const FlatAffineValueConstraints & cst,unsigned pos,int64_t lbConst,int64_t ubConst,SmallVectorImpl<AffineExpr> & memo,MLIRContext * context)745 static bool detectAsMod(const FlatAffineValueConstraints &cst, unsigned pos,
746                         int64_t lbConst, int64_t ubConst,
747                         SmallVectorImpl<AffineExpr> &memo,
748                         MLIRContext *context) {
749   assert(pos < cst.getNumVars() && "invalid position");
750 
751   // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can
752   // be determined.
753   if (lbConst != 0 || ubConst < 1)
754     return false;
755   int64_t divisor = ubConst + 1;
756 
757   // Check for the aforementioned conditions in each equality.
758   for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
759        curEquality < numEqualities; curEquality++) {
760     int64_t coefficientAtPos = cst.atEq(curEquality, pos);
761     // If current equality does not involve `var_r`, continue to the next
762     // equality.
763     if (coefficientAtPos == 0)
764       continue;
765 
766     // Constant term should be 0 in this equality.
767     if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0)
768       continue;
769 
770     // Traverse through the equality and construct the dividend expression
771     // `dividendExpr`, to contain all the variables which are known and are
772     // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the
773     // `dividendExpr` gets simplified into a single variable `var_n` discussed
774     // above.
775     auto dividendExpr = getAffineConstantExpr(0, context);
776 
777     // Track the terms that go into quotient expression, later used to detect
778     // additional floordiv.
779     unsigned quotientCount = 0;
780     int quotientPosition = -1;
781     int quotientSign = 1;
782 
783     // Consider each term in the current equality.
784     unsigned curVar, e;
785     for (curVar = 0, e = cst.getNumDimAndSymbolVars(); curVar < e; ++curVar) {
786       // Ignore var_r.
787       if (curVar == pos)
788         continue;
789       int64_t coefficientOfCurVar = cst.atEq(curEquality, curVar);
790       // Ignore vars that do not contribute to the current equality.
791       if (coefficientOfCurVar == 0)
792         continue;
793       // Check if the current var goes into the quotient expression.
794       if (coefficientOfCurVar % (divisor * coefficientAtPos) == 0) {
795         quotientCount++;
796         quotientPosition = curVar;
797         quotientSign = (coefficientOfCurVar * coefficientAtPos) > 0 ? 1 : -1;
798         continue;
799       }
800       // Variables that are part of dividendExpr should be known.
801       if (!memo[curVar])
802         break;
803       // Append the current variable to the dividend expression.
804       dividendExpr = dividendExpr + memo[curVar] * coefficientOfCurVar;
805     }
806 
807     // Can't construct expression as it depends on a yet uncomputed var.
808     if (curVar < e)
809       continue;
810 
811     // Express `var_r` in terms of the other vars collected so far.
812     if (coefficientAtPos > 0)
813       dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
814     else
815       dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);
816 
817     // Simplify the expression.
818     dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimVars(),
819                                       cst.getNumSymbolVars());
820     // Only if the final dividend expression is just a single var (which we call
821     // `var_n`), we can proceed.
822     // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
823     // to dims themselves.
824     auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
825     if (!dimExpr)
826       continue;
827 
828     // Express `var_r` as `var_n % divisor` and store the expression in `memo`.
829     if (quotientCount >= 1) {
830       auto ub = cst.getConstantBound(FlatAffineValueConstraints::BoundType::UB,
831                                      dimExpr.getPosition());
832       // If `var_n` has an upperbound that is less than the divisor, mod can be
833       // eliminated altogether.
834       if (ub && *ub < divisor)
835         memo[pos] = dimExpr;
836       else
837         memo[pos] = dimExpr % divisor;
838       // If a unique quotient `var_q` was seen, it can be expressed as
839       // `var_n floordiv divisor`.
840       if (quotientCount == 1 && !memo[quotientPosition])
841         memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign;
842 
843       return true;
844     }
845   }
846   return false;
847 }
848 
849 /// Check if the pos^th variable can be expressed as a floordiv of an affine
850 /// function of other variables (where the divisor is a positive constant)
851 /// given the initial set of expressions in `exprs`. If it can be, the
852 /// corresponding position in `exprs` is set as the detected affine expr. For
853 /// eg: 4q <= i + j <= 4q + 3   <=>   q = (i + j) floordiv 4. An equality can
854 /// also yield a floordiv: eg.  4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
855 /// <= i <= 32q + 31 => q = i floordiv 32.
detectAsFloorDiv(const FlatAffineValueConstraints & cst,unsigned pos,MLIRContext * context,SmallVectorImpl<AffineExpr> & exprs)856 static bool detectAsFloorDiv(const FlatAffineValueConstraints &cst,
857                              unsigned pos, MLIRContext *context,
858                              SmallVectorImpl<AffineExpr> &exprs) {
859   assert(pos < cst.getNumVars() && "invalid position");
860 
861   // Get upper-lower bound pair for this variable.
862   SmallVector<bool, 8> foundRepr(cst.getNumVars(), false);
863   for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i)
864     if (exprs[i])
865       foundRepr[i] = true;
866 
867   SmallVector<int64_t, 8> dividend(cst.getNumCols());
868   unsigned divisor;
869   auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor);
870 
871   // No upper-lower bound pair found for this var.
872   if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality)
873     return false;
874 
875   // Construct the dividend expression.
876   auto dividendExpr = getAffineConstantExpr(dividend.back(), context);
877   for (unsigned c = 0, f = cst.getNumVars(); c < f; c++)
878     if (dividend[c] != 0)
879       dividendExpr = dividendExpr + dividend[c] * exprs[c];
880 
881   // Successfully detected the floordiv.
882   exprs[pos] = dividendExpr.floorDiv(divisor);
883   return true;
884 }
885 
886 std::pair<AffineMap, AffineMap>
getLowerAndUpperBound(unsigned pos,unsigned offset,unsigned num,unsigned symStartPos,ArrayRef<AffineExpr> localExprs,MLIRContext * context) const887 FlatAffineValueConstraints::getLowerAndUpperBound(
888     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
889     ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
890   assert(pos + offset < getNumDimVars() && "invalid dim start pos");
891   assert(symStartPos >= (pos + offset) && "invalid sym start pos");
892   assert(getNumLocalVars() == localExprs.size() &&
893          "incorrect local exprs count");
894 
895   SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
896   getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
897                                offset, num);
898 
899   /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
900   auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
901     b.clear();
902     for (unsigned i = 0, e = a.size(); i < e; ++i) {
903       if (i < offset || i >= offset + num)
904         b.push_back(a[i]);
905     }
906   };
907 
908   SmallVector<int64_t, 8> lb, ub;
909   SmallVector<AffineExpr, 4> lbExprs;
910   unsigned dimCount = symStartPos - num;
911   unsigned symCount = getNumDimAndSymbolVars() - symStartPos;
912   lbExprs.reserve(lbIndices.size() + eqIndices.size());
913   // Lower bound expressions.
914   for (auto idx : lbIndices) {
915     auto ineq = getInequality(idx);
916     // Extract the lower bound (in terms of other coeff's + const), i.e., if
917     // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
918     // - 1.
919     addCoeffs(ineq, lb);
920     std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
921     auto expr =
922         getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
923     // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
924     int64_t divisor = std::abs(ineq[pos + offset]);
925     expr = (expr + divisor - 1).floorDiv(divisor);
926     lbExprs.push_back(expr);
927   }
928 
929   SmallVector<AffineExpr, 4> ubExprs;
930   ubExprs.reserve(ubIndices.size() + eqIndices.size());
931   // Upper bound expressions.
932   for (auto idx : ubIndices) {
933     auto ineq = getInequality(idx);
934     // Extract the upper bound (in terms of other coeff's + const).
935     addCoeffs(ineq, ub);
936     auto expr =
937         getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
938     expr = expr.floorDiv(std::abs(ineq[pos + offset]));
939     // Upper bound is exclusive.
940     ubExprs.push_back(expr + 1);
941   }
942 
943   // Equalities. It's both a lower and a upper bound.
944   SmallVector<int64_t, 4> b;
945   for (auto idx : eqIndices) {
946     auto eq = getEquality(idx);
947     addCoeffs(eq, b);
948     if (eq[pos + offset] > 0)
949       std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());
950 
951     // Extract the upper bound (in terms of other coeff's + const).
952     auto expr =
953         getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
954     expr = expr.floorDiv(std::abs(eq[pos + offset]));
955     // Upper bound is exclusive.
956     ubExprs.push_back(expr + 1);
957     // Lower bound.
958     expr =
959         getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
960     expr = expr.ceilDiv(std::abs(eq[pos + offset]));
961     lbExprs.push_back(expr);
962   }
963 
964   auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
965   auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);
966 
967   return {lbMap, ubMap};
968 }
969 
970 /// Computes the lower and upper bounds of the first 'num' dimensional
971 /// variables (starting at 'offset') as affine maps of the remaining
972 /// variables (dimensional and symbolic variables). Local variables are
973 /// themselves explicitly computed as affine functions of other variables in
974 /// this process if needed.
getSliceBounds(unsigned offset,unsigned num,MLIRContext * context,SmallVectorImpl<AffineMap> * lbMaps,SmallVectorImpl<AffineMap> * ubMaps,bool getClosedUB)975 void FlatAffineValueConstraints::getSliceBounds(
976     unsigned offset, unsigned num, MLIRContext *context,
977     SmallVectorImpl<AffineMap> *lbMaps, SmallVectorImpl<AffineMap> *ubMaps,
978     bool getClosedUB) {
979   assert(num < getNumDimVars() && "invalid range");
980 
981   // Basic simplification.
982   normalizeConstraintsByGCD();
983 
984   LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
985                           << " variables\n");
986   LLVM_DEBUG(dump());
987 
988   // Record computed/detected variables.
989   SmallVector<AffineExpr, 8> memo(getNumVars());
990   // Initialize dimensional and symbolic variables.
991   for (unsigned i = 0, e = getNumDimVars(); i < e; i++) {
992     if (i < offset)
993       memo[i] = getAffineDimExpr(i, context);
994     else if (i >= offset + num)
995       memo[i] = getAffineDimExpr(i - num, context);
996   }
997   for (unsigned i = getNumDimVars(), e = getNumDimAndSymbolVars(); i < e; i++)
998     memo[i] = getAffineSymbolExpr(i - getNumDimVars(), context);
999 
1000   bool changed;
1001   do {
1002     changed = false;
1003     // Identify yet unknown variables as constants or mod's / floordiv's of
1004     // other variables if possible.
1005     for (unsigned pos = 0; pos < getNumVars(); pos++) {
1006       if (memo[pos])
1007         continue;
1008 
1009       auto lbConst = getConstantBound(BoundType::LB, pos);
1010       auto ubConst = getConstantBound(BoundType::UB, pos);
1011       if (lbConst.has_value() && ubConst.has_value()) {
1012         // Detect equality to a constant.
1013         if (lbConst.value() == ubConst.value()) {
1014           memo[pos] = getAffineConstantExpr(lbConst.value(), context);
1015           changed = true;
1016           continue;
1017         }
1018 
1019         // Detect an variable as modulo of another variable w.r.t a
1020         // constant.
1021         if (detectAsMod(*this, pos, lbConst.value(), ubConst.value(), memo,
1022                         context)) {
1023           changed = true;
1024           continue;
1025         }
1026       }
1027 
1028       // Detect an variable as a floordiv of an affine function of other
1029       // variables (divisor is a positive constant).
1030       if (detectAsFloorDiv(*this, pos, context, memo)) {
1031         changed = true;
1032         continue;
1033       }
1034 
1035       // Detect an variable as an expression of other variables.
1036       unsigned idx;
1037       if (!findConstraintWithNonZeroAt(pos, /*isEq=*/true, &idx)) {
1038         continue;
1039       }
1040 
1041       // Build AffineExpr solving for variable 'pos' in terms of all others.
1042       auto expr = getAffineConstantExpr(0, context);
1043       unsigned j, e;
1044       for (j = 0, e = getNumVars(); j < e; ++j) {
1045         if (j == pos)
1046           continue;
1047         int64_t c = atEq(idx, j);
1048         if (c == 0)
1049           continue;
1050         // If any of the involved IDs hasn't been found yet, we can't proceed.
1051         if (!memo[j])
1052           break;
1053         expr = expr + memo[j] * c;
1054       }
1055       if (j < e)
1056         // Can't construct expression as it depends on a yet uncomputed
1057         // variable.
1058         continue;
1059 
1060       // Add constant term to AffineExpr.
1061       expr = expr + atEq(idx, getNumVars());
1062       int64_t vPos = atEq(idx, pos);
1063       assert(vPos != 0 && "expected non-zero here");
1064       if (vPos > 0)
1065         expr = (-expr).floorDiv(vPos);
1066       else
1067         // vPos < 0.
1068         expr = expr.floorDiv(-vPos);
1069       // Successfully constructed expression.
1070       memo[pos] = expr;
1071       changed = true;
1072     }
1073     // This loop is guaranteed to reach a fixed point - since once an
1074     // variable's explicit form is computed (in memo[pos]), it's not updated
1075     // again.
1076   } while (changed);
1077 
1078   int64_t ubAdjustment = getClosedUB ? 0 : 1;
1079 
1080   // Set the lower and upper bound maps for all the variables that were
1081   // computed as affine expressions of the rest as the "detected expr" and
1082   // "detected expr + 1" respectively; set the undetected ones to null.
1083   Optional<FlatAffineValueConstraints> tmpClone;
1084   for (unsigned pos = 0; pos < num; pos++) {
1085     unsigned numMapDims = getNumDimVars() - num;
1086     unsigned numMapSymbols = getNumSymbolVars();
1087     AffineExpr expr = memo[pos + offset];
1088     if (expr)
1089       expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);
1090 
1091     AffineMap &lbMap = (*lbMaps)[pos];
1092     AffineMap &ubMap = (*ubMaps)[pos];
1093 
1094     if (expr) {
1095       lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
1096       ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + ubAdjustment);
1097     } else {
1098       // TODO: Whenever there are local variables in the dependence
1099       // constraints, we'll conservatively over-approximate, since we don't
1100       // always explicitly compute them above (in the while loop).
1101       if (getNumLocalVars() == 0) {
1102         // Work on a copy so that we don't update this constraint system.
1103         if (!tmpClone) {
1104           tmpClone.emplace(FlatAffineValueConstraints(*this));
1105           // Removing redundant inequalities is necessary so that we don't get
1106           // redundant loop bounds.
1107           tmpClone->removeRedundantInequalities();
1108         }
1109         std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
1110             pos, offset, num, getNumDimVars(), /*localExprs=*/{}, context);
1111       }
1112 
1113       // If the above fails, we'll just use the constant lower bound and the
1114       // constant upper bound (if they exist) as the slice bounds.
1115       // TODO: being conservative for the moment in cases that
1116       // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
1117       // fixed (b/126426796).
1118       if (!lbMap || lbMap.getNumResults() > 1) {
1119         LLVM_DEBUG(llvm::dbgs()
1120                    << "WARNING: Potentially over-approximating slice lb\n");
1121         auto lbConst = getConstantBound(BoundType::LB, pos + offset);
1122         if (lbConst.has_value()) {
1123           lbMap =
1124               AffineMap::get(numMapDims, numMapSymbols,
1125                              getAffineConstantExpr(lbConst.value(), context));
1126         }
1127       }
1128       if (!ubMap || ubMap.getNumResults() > 1) {
1129         LLVM_DEBUG(llvm::dbgs()
1130                    << "WARNING: Potentially over-approximating slice ub\n");
1131         auto ubConst = getConstantBound(BoundType::UB, pos + offset);
1132         if (ubConst.has_value()) {
1133           ubMap = AffineMap::get(
1134               numMapDims, numMapSymbols,
1135               getAffineConstantExpr(ubConst.value() + ubAdjustment, context));
1136         }
1137       }
1138     }
1139     LLVM_DEBUG(llvm::dbgs()
1140                << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
1141     LLVM_DEBUG(lbMap.dump(););
1142     LLVM_DEBUG(llvm::dbgs()
1143                << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
1144     LLVM_DEBUG(ubMap.dump(););
1145   }
1146 }
1147 
flattenAlignedMapAndMergeLocals(AffineMap map,std::vector<SmallVector<int64_t,8>> * flattenedExprs)1148 LogicalResult FlatAffineValueConstraints::flattenAlignedMapAndMergeLocals(
1149     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) {
1150   FlatAffineValueConstraints localCst;
1151   if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) {
1152     LLVM_DEBUG(llvm::dbgs()
1153                << "composition unimplemented for semi-affine maps\n");
1154     return failure();
1155   }
1156 
1157   // Add localCst information.
1158   if (localCst.getNumLocalVars() > 0) {
1159     unsigned numLocalVars = getNumLocalVars();
1160     // Insert local dims of localCst at the beginning.
1161     insertLocalVar(/*pos=*/0, /*num=*/localCst.getNumLocalVars());
1162     // Insert local dims of `this` at the end of localCst.
1163     localCst.appendLocalVar(/*num=*/numLocalVars);
1164     // Dimensions of localCst and this constraint set match. Append localCst to
1165     // this constraint set.
1166     append(localCst);
1167   }
1168 
1169   return success();
1170 }
1171 
addBound(BoundType type,unsigned pos,AffineMap boundMap,bool isClosedBound)1172 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
1173                                                    AffineMap boundMap,
1174                                                    bool isClosedBound) {
1175   assert(boundMap.getNumDims() == getNumDimVars() && "dim mismatch");
1176   assert(boundMap.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
1177   assert(pos < getNumDimAndSymbolVars() && "invalid position");
1178   assert((type != BoundType::EQ || isClosedBound) &&
1179          "EQ bound must be closed.");
1180 
1181   // Equality follows the logic of lower bound except that we add an equality
1182   // instead of an inequality.
1183   assert((type != BoundType::EQ || boundMap.getNumResults() == 1) &&
1184          "single result expected");
1185   bool lower = type == BoundType::LB || type == BoundType::EQ;
1186 
1187   std::vector<SmallVector<int64_t, 8>> flatExprs;
1188   if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs)))
1189     return failure();
1190   assert(flatExprs.size() == boundMap.getNumResults());
1191 
1192   // Add one (in)equality for each result.
1193   for (const auto &flatExpr : flatExprs) {
1194     SmallVector<int64_t> ineq(getNumCols(), 0);
1195     // Dims and symbols.
1196     for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
1197       ineq[j] = lower ? -flatExpr[j] : flatExpr[j];
1198     }
1199     // Invalid bound: pos appears in `boundMap`.
1200     // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or
1201     // its callers to prevent invalid bounds from being added.
1202     if (ineq[pos] != 0)
1203       continue;
1204     ineq[pos] = lower ? 1 : -1;
1205     // Local columns of `ineq` are at the beginning.
1206     unsigned j = getNumDimVars() + getNumSymbolVars();
1207     unsigned end = flatExpr.size() - 1;
1208     for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
1209       ineq[j] = lower ? -flatExpr[i] : flatExpr[i];
1210     }
1211     // Make the bound closed in if flatExpr is open. The inequality is always
1212     // created in the upper bound form, so the adjustment is -1.
1213     int64_t boundAdjustment = (isClosedBound || type == BoundType::EQ) ? 0 : -1;
1214     // Constant term.
1215     ineq[getNumCols() - 1] = (lower ? -flatExpr[flatExpr.size() - 1]
1216                                     : flatExpr[flatExpr.size() - 1]) +
1217                              boundAdjustment;
1218     type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq);
1219   }
1220 
1221   return success();
1222 }
1223 
addBound(BoundType type,unsigned pos,AffineMap boundMap)1224 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
1225                                                    AffineMap boundMap) {
1226   return addBound(type, pos, boundMap, /*isClosedBound=*/type != BoundType::UB);
1227 }
1228 
1229 AffineMap
computeAlignedMap(AffineMap map,ValueRange operands) const1230 FlatAffineValueConstraints::computeAlignedMap(AffineMap map,
1231                                               ValueRange operands) const {
1232   assert(map.getNumInputs() == operands.size() && "number of inputs mismatch");
1233 
1234   SmallVector<Value> dims, syms;
1235 #ifndef NDEBUG
1236   SmallVector<Value> newSyms;
1237   SmallVector<Value> *newSymsPtr = &newSyms;
1238 #else
1239   SmallVector<Value> *newSymsPtr = nullptr;
1240 #endif // NDEBUG
1241 
1242   dims.reserve(getNumDimVars());
1243   syms.reserve(getNumSymbolVars());
1244   for (unsigned i = getVarKindOffset(VarKind::SetDim),
1245                 e = getVarKindEnd(VarKind::SetDim);
1246        i < e; ++i)
1247     dims.push_back(values[i] ? *values[i] : Value());
1248   for (unsigned i = getVarKindOffset(VarKind::Symbol),
1249                 e = getVarKindEnd(VarKind::Symbol);
1250        i < e; ++i)
1251     syms.push_back(values[i] ? *values[i] : Value());
1252 
1253   AffineMap alignedMap =
1254       alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
1255   // All symbols are already part of this FlatAffineConstraints.
1256   assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
1257   assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
1258          "unexpected new/missing symbols");
1259   return alignedMap;
1260 }
1261 
addBound(BoundType type,unsigned pos,AffineMap boundMap,ValueRange boundOperands)1262 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
1263                                                    AffineMap boundMap,
1264                                                    ValueRange boundOperands) {
1265   // Fully compose map and operands; canonicalize and simplify so that we
1266   // transitively get to terminal symbols or loop IVs.
1267   auto map = boundMap;
1268   SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
1269   fullyComposeAffineMapAndOperands(&map, &operands);
1270   map = simplifyAffineMap(map);
1271   canonicalizeMapAndOperands(&map, &operands);
1272   for (auto operand : operands)
1273     addInductionVarOrTerminalSymbol(operand);
1274   return addBound(type, pos, computeAlignedMap(map, operands));
1275 }
1276 
1277 // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
1278 // bounds in 'ubMaps' to each value in `values' that appears in the constraint
1279 // system. Note that both lower/upper bounds share the same operand list
1280 // 'operands'.
1281 // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and
1282 // skips any null AffineMaps in 'lbMaps' or 'ubMaps'.
1283 // Note that both lower/upper bounds use operands from 'operands'.
1284 // Returns failure for unimplemented cases such as semi-affine expressions or
1285 // expressions with mod/floordiv.
addSliceBounds(ArrayRef<Value> values,ArrayRef<AffineMap> lbMaps,ArrayRef<AffineMap> ubMaps,ArrayRef<Value> operands)1286 LogicalResult FlatAffineValueConstraints::addSliceBounds(
1287     ArrayRef<Value> values, ArrayRef<AffineMap> lbMaps,
1288     ArrayRef<AffineMap> ubMaps, ArrayRef<Value> operands) {
1289   assert(values.size() == lbMaps.size());
1290   assert(lbMaps.size() == ubMaps.size());
1291 
1292   for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
1293     unsigned pos;
1294     if (!findVar(values[i], &pos))
1295       continue;
1296 
1297     AffineMap lbMap = lbMaps[i];
1298     AffineMap ubMap = ubMaps[i];
1299     assert(!lbMap || lbMap.getNumInputs() == operands.size());
1300     assert(!ubMap || ubMap.getNumInputs() == operands.size());
1301 
1302     // Check if this slice is just an equality along this dimension.
1303     if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
1304         ubMap.getNumResults() == 1 &&
1305         lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
1306       if (failed(addBound(BoundType::EQ, pos, lbMap, operands)))
1307         return failure();
1308       continue;
1309     }
1310 
1311     // If lower or upper bound maps are null or provide no results, it implies
1312     // that the source loop was not at all sliced, and the entire loop will be a
1313     // part of the slice.
1314     if (lbMap && lbMap.getNumResults() != 0 && ubMap &&
1315         ubMap.getNumResults() != 0) {
1316       if (failed(addBound(BoundType::LB, pos, lbMap, operands)))
1317         return failure();
1318       if (failed(addBound(BoundType::UB, pos, ubMap, operands)))
1319         return failure();
1320     } else {
1321       auto loop = getForInductionVarOwner(values[i]);
1322       if (failed(this->addAffineForOpDomain(loop)))
1323         return failure();
1324     }
1325   }
1326   return success();
1327 }
1328 
findVar(Value val,unsigned * pos) const1329 bool FlatAffineValueConstraints::findVar(Value val, unsigned *pos) const {
1330   unsigned i = 0;
1331   for (const auto &mayBeVar : values) {
1332     if (mayBeVar && *mayBeVar == val) {
1333       *pos = i;
1334       return true;
1335     }
1336     i++;
1337   }
1338   return false;
1339 }
1340 
containsVar(Value val) const1341 bool FlatAffineValueConstraints::containsVar(Value val) const {
1342   return llvm::any_of(values, [&](const Optional<Value> &mayBeVar) {
1343     return mayBeVar && *mayBeVar == val;
1344   });
1345 }
1346 
swapVar(unsigned posA,unsigned posB)1347 void FlatAffineValueConstraints::swapVar(unsigned posA, unsigned posB) {
1348   IntegerPolyhedron::swapVar(posA, posB);
1349 
1350   if (getVarKindAt(posA) == VarKind::Local &&
1351       getVarKindAt(posB) == VarKind::Local)
1352     return;
1353 
1354   // Treat value of a local variable as None.
1355   if (getVarKindAt(posA) == VarKind::Local)
1356     values[posB] = None;
1357   else if (getVarKindAt(posB) == VarKind::Local)
1358     values[posA] = None;
1359   else
1360     std::swap(values[posA], values[posB]);
1361 }
1362 
addBound(BoundType type,Value val,int64_t value)1363 void FlatAffineValueConstraints::addBound(BoundType type, Value val,
1364                                           int64_t value) {
1365   unsigned pos;
1366   if (!findVar(val, &pos))
1367     // This is a pre-condition for this method.
1368     assert(0 && "var not found");
1369   addBound(type, pos, value);
1370 }
1371 
printSpace(raw_ostream & os) const1372 void FlatAffineValueConstraints::printSpace(raw_ostream &os) const {
1373   IntegerPolyhedron::printSpace(os);
1374   os << "(";
1375   for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) {
1376     if (hasValue(i))
1377       os << "Value ";
1378     else
1379       os << "None ";
1380   }
1381   for (unsigned i = getVarKindOffset(VarKind::Local),
1382                 e = getVarKindEnd(VarKind::Local);
1383        i < e; ++i)
1384     os << "Local ";
1385   os << " const)\n";
1386 }
1387 
clearAndCopyFrom(const IntegerRelation & other)1388 void FlatAffineValueConstraints::clearAndCopyFrom(
1389     const IntegerRelation &other) {
1390 
1391   if (auto *otherValueSet =
1392           dyn_cast<const FlatAffineValueConstraints>(&other)) {
1393     *this = *otherValueSet;
1394   } else {
1395     *static_cast<IntegerRelation *>(this) = other;
1396     values.clear();
1397     values.resize(getNumDimAndSymbolVars(), None);
1398   }
1399 }
1400 
fourierMotzkinEliminate(unsigned pos,bool darkShadow,bool * isResultIntegerExact)1401 void FlatAffineValueConstraints::fourierMotzkinEliminate(
1402     unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
1403   SmallVector<Optional<Value>, 8> newVals = values;
1404   if (getVarKindAt(pos) != VarKind::Local)
1405     newVals.erase(newVals.begin() + pos);
1406   // Note: Base implementation discards all associated Values.
1407   IntegerPolyhedron::fourierMotzkinEliminate(pos, darkShadow,
1408                                              isResultIntegerExact);
1409   values = newVals;
1410   assert(values.size() == getNumDimAndSymbolVars());
1411 }
1412 
projectOut(Value val)1413 void FlatAffineValueConstraints::projectOut(Value val) {
1414   unsigned pos;
1415   bool ret = findVar(val, &pos);
1416   assert(ret);
1417   (void)ret;
1418   fourierMotzkinEliminate(pos);
1419 }
1420 
unionBoundingBox(const FlatAffineValueConstraints & otherCst)1421 LogicalResult FlatAffineValueConstraints::unionBoundingBox(
1422     const FlatAffineValueConstraints &otherCst) {
1423   assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch");
1424   assert(otherCst.getMaybeValues()
1425              .slice(0, getNumDimVars())
1426              .equals(getMaybeValues().slice(0, getNumDimVars())) &&
1427          "dim values mismatch");
1428   assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
1429   assert(getNumLocalVars() == 0 && "local vars not supported yet here");
1430 
1431   // Align `other` to this.
1432   if (!areVarsAligned(*this, otherCst)) {
1433     FlatAffineValueConstraints otherCopy(otherCst);
1434     mergeAndAlignVars(/*offset=*/getNumDimVars(), this, &otherCopy);
1435     return IntegerPolyhedron::unionBoundingBox(otherCopy);
1436   }
1437 
1438   return IntegerPolyhedron::unionBoundingBox(otherCst);
1439 }
1440 
1441 /// Compute an explicit representation for local vars. For all systems coming
1442 /// from MLIR integer sets, maps, or expressions where local vars were
1443 /// introduced to model floordivs and mods, this always succeeds.
computeLocalVars(const FlatAffineValueConstraints & cst,SmallVectorImpl<AffineExpr> & memo,MLIRContext * context)1444 static LogicalResult computeLocalVars(const FlatAffineValueConstraints &cst,
1445                                       SmallVectorImpl<AffineExpr> &memo,
1446                                       MLIRContext *context) {
1447   unsigned numDims = cst.getNumDimVars();
1448   unsigned numSyms = cst.getNumSymbolVars();
1449 
1450   // Initialize dimensional and symbolic variables.
1451   for (unsigned i = 0; i < numDims; i++)
1452     memo[i] = getAffineDimExpr(i, context);
1453   for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
1454     memo[i] = getAffineSymbolExpr(i - numDims, context);
1455 
1456   bool changed;
1457   do {
1458     // Each time `changed` is true at the end of this iteration, one or more
1459     // local vars would have been detected as floordivs and set in memo; so the
1460     // number of null entries in memo[...] strictly reduces; so this converges.
1461     changed = false;
1462     for (unsigned i = 0, e = cst.getNumLocalVars(); i < e; ++i)
1463       if (!memo[numDims + numSyms + i] &&
1464           detectAsFloorDiv(cst, /*pos=*/numDims + numSyms + i, context, memo))
1465         changed = true;
1466   } while (changed);
1467 
1468   ArrayRef<AffineExpr> localExprs =
1469       ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalVars());
1470   return success(
1471       llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
1472 }
1473 
getIneqAsAffineValueMap(unsigned pos,unsigned ineqPos,AffineValueMap & vmap,MLIRContext * context) const1474 void FlatAffineValueConstraints::getIneqAsAffineValueMap(
1475     unsigned pos, unsigned ineqPos, AffineValueMap &vmap,
1476     MLIRContext *context) const {
1477   unsigned numDims = getNumDimVars();
1478   unsigned numSyms = getNumSymbolVars();
1479 
1480   assert(pos < numDims && "invalid position");
1481   assert(ineqPos < getNumInequalities() && "invalid inequality position");
1482 
1483   // Get expressions for local vars.
1484   SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr());
1485   if (failed(computeLocalVars(*this, memo, context)))
1486     assert(false &&
1487            "one or more local exprs do not have an explicit representation");
1488   auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars());
1489 
1490   // Compute the AffineExpr lower/upper bound for this inequality.
1491   ArrayRef<int64_t> inequality = getInequality(ineqPos);
1492   SmallVector<int64_t, 8> bound;
1493   bound.reserve(getNumCols() - 1);
1494   // Everything other than the coefficient at `pos`.
1495   bound.append(inequality.begin(), inequality.begin() + pos);
1496   bound.append(inequality.begin() + pos + 1, inequality.end());
1497 
1498   if (inequality[pos] > 0)
1499     // Lower bound.
1500     std::transform(bound.begin(), bound.end(), bound.begin(),
1501                    std::negate<int64_t>());
1502   else
1503     // Upper bound (which is exclusive).
1504     bound.back() += 1;
1505 
1506   // Convert to AffineExpr (tree) form.
1507   auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms,
1508                                              localExprs, context);
1509 
1510   // Get the values to bind to this affine expr (all dims and symbols).
1511   SmallVector<Value, 4> operands;
1512   getValues(0, pos, &operands);
1513   SmallVector<Value, 4> trailingOperands;
1514   getValues(pos + 1, getNumDimAndSymbolVars(), &trailingOperands);
1515   operands.append(trailingOperands.begin(), trailingOperands.end());
1516   vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands);
1517 }
1518 
1519 IntegerSet
getAsIntegerSet(MLIRContext * context) const1520 FlatAffineValueConstraints::getAsIntegerSet(MLIRContext *context) const {
1521   if (getNumConstraints() == 0)
1522     // Return universal set (always true): 0 == 0.
1523     return IntegerSet::get(getNumDimVars(), getNumSymbolVars(),
1524                            getAffineConstantExpr(/*constant=*/0, context),
1525                            /*eqFlags=*/true);
1526 
1527   // Construct local references.
1528   SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr());
1529 
1530   if (failed(computeLocalVars(*this, memo, context))) {
1531     // Check if the local variables without an explicit representation have
1532     // zero coefficients everywhere.
1533     SmallVector<unsigned> noLocalRepVars;
1534     unsigned numDimsSymbols = getNumDimAndSymbolVars();
1535     for (unsigned i = numDimsSymbols, e = getNumVars(); i < e; ++i) {
1536       if (!memo[i] && !isColZero(/*pos=*/i))
1537         noLocalRepVars.push_back(i - numDimsSymbols);
1538     }
1539     if (!noLocalRepVars.empty()) {
1540       LLVM_DEBUG({
1541         llvm::dbgs() << "local variables at position(s) ";
1542         llvm::interleaveComma(noLocalRepVars, llvm::dbgs());
1543         llvm::dbgs() << " do not have an explicit representation in:\n";
1544         this->dump();
1545       });
1546       return IntegerSet();
1547     }
1548   }
1549 
1550   ArrayRef<AffineExpr> localExprs =
1551       ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars());
1552 
1553   // Construct the IntegerSet from the equalities/inequalities.
1554   unsigned numDims = getNumDimVars();
1555   unsigned numSyms = getNumSymbolVars();
1556 
1557   SmallVector<bool, 16> eqFlags(getNumConstraints());
1558   std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
1559   std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false);
1560 
1561   SmallVector<AffineExpr, 8> exprs;
1562   exprs.reserve(getNumConstraints());
1563 
1564   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
1565     exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms,
1566                                               localExprs, context));
1567   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
1568     exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims,
1569                                               numSyms, localExprs, context));
1570   return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
1571 }
1572 
alignAffineMapWithValues(AffineMap map,ValueRange operands,ValueRange dims,ValueRange syms,SmallVector<Value> * newSyms)1573 AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
1574                                          ValueRange dims, ValueRange syms,
1575                                          SmallVector<Value> *newSyms) {
1576   assert(operands.size() == map.getNumInputs() &&
1577          "expected same number of operands and map inputs");
1578   MLIRContext *ctx = map.getContext();
1579   Builder builder(ctx);
1580   SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {});
1581   unsigned numSymbols = syms.size();
1582   SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {});
1583   if (newSyms) {
1584     newSyms->clear();
1585     newSyms->append(syms.begin(), syms.end());
1586   }
1587 
1588   for (const auto &operand : llvm::enumerate(operands)) {
1589     // Compute replacement dim/sym of operand.
1590     AffineExpr replacement;
1591     auto dimIt = std::find(dims.begin(), dims.end(), operand.value());
1592     auto symIt = std::find(syms.begin(), syms.end(), operand.value());
1593     if (dimIt != dims.end()) {
1594       replacement =
1595           builder.getAffineDimExpr(std::distance(dims.begin(), dimIt));
1596     } else if (symIt != syms.end()) {
1597       replacement =
1598           builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt));
1599     } else {
1600       // This operand is neither a dimension nor a symbol. Add it as a new
1601       // symbol.
1602       replacement = builder.getAffineSymbolExpr(numSymbols++);
1603       if (newSyms)
1604         newSyms->push_back(operand.value());
1605     }
1606     // Add to corresponding replacements vector.
1607     if (operand.index() < map.getNumDims()) {
1608       dimReplacements[operand.index()] = replacement;
1609     } else {
1610       symReplacements[operand.index() - map.getNumDims()] = replacement;
1611     }
1612   }
1613 
1614   return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
1615                                    dims.size(), numSymbols);
1616 }
1617 
getDomainSet() const1618 FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const {
1619   FlatAffineValueConstraints domain = *this;
1620   // Convert all range variables to local variables.
1621   domain.convertToLocal(VarKind::SetDim, getNumDomainDims(),
1622                         getNumDomainDims() + getNumRangeDims());
1623   return domain;
1624 }
1625 
getRangeSet() const1626 FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const {
1627   FlatAffineValueConstraints range = *this;
1628   // Convert all domain variables to local variables.
1629   range.convertToLocal(VarKind::SetDim, 0, getNumDomainDims());
1630   return range;
1631 }
1632 
compose(const FlatAffineRelation & other)1633 void FlatAffineRelation::compose(const FlatAffineRelation &other) {
1634   assert(getNumDomainDims() == other.getNumRangeDims() &&
1635          "Domain of this and range of other do not match");
1636   assert(std::equal(values.begin(), values.begin() + getNumDomainDims(),
1637                     other.values.begin() + other.getNumDomainDims()) &&
1638          "Domain of this and range of other do not match");
1639 
1640   FlatAffineRelation rel = other;
1641 
1642   // Convert `rel` from
1643   //    [otherDomain] -> [otherRange]
1644   // to
1645   //    [otherDomain] -> [otherRange thisRange]
1646   // and `this` from
1647   //    [thisDomain] -> [thisRange]
1648   // to
1649   //    [otherDomain thisDomain] -> [thisRange].
1650   unsigned removeDims = rel.getNumRangeDims();
1651   insertDomainVar(0, rel.getNumDomainDims());
1652   rel.appendRangeVar(getNumRangeDims());
1653 
1654   // Merge symbol and local variables.
1655   mergeSymbolVars(rel);
1656   mergeLocalVars(rel);
1657 
1658   // Convert `rel` from [otherDomain] -> [otherRange thisRange] to
1659   // [otherDomain] -> [thisRange] by converting first otherRange range vars
1660   // to local vars.
1661   rel.convertToLocal(VarKind::SetDim, rel.getNumDomainDims(),
1662                      rel.getNumDomainDims() + removeDims);
1663   // Convert `this` from [otherDomain thisDomain] -> [thisRange] to
1664   // [otherDomain] -> [thisRange] by converting last thisDomain domain vars
1665   // to local vars.
1666   convertToLocal(VarKind::SetDim, getNumDomainDims() - removeDims,
1667                  getNumDomainDims());
1668 
1669   auto thisMaybeValues = getMaybeValues(VarKind::SetDim);
1670   auto relMaybeValues = rel.getMaybeValues(VarKind::SetDim);
1671 
1672   // Add and match domain of `rel` to domain of `this`.
1673   for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
1674     if (relMaybeValues[i].has_value())
1675       setValue(i, relMaybeValues[i].value());
1676   // Add and match range of `this` to range of `rel`.
1677   for (unsigned i = 0, e = getNumRangeDims(); i < e; ++i) {
1678     unsigned rangeIdx = rel.getNumDomainDims() + i;
1679     if (thisMaybeValues[rangeIdx].has_value())
1680       rel.setValue(rangeIdx, thisMaybeValues[rangeIdx].value());
1681   }
1682 
1683   // Append `this` to `rel` and simplify constraints.
1684   rel.append(*this);
1685   rel.removeRedundantLocalVars();
1686 
1687   *this = rel;
1688 }
1689 
inverse()1690 void FlatAffineRelation::inverse() {
1691   unsigned oldDomain = getNumDomainDims();
1692   unsigned oldRange = getNumRangeDims();
1693   // Add new range vars.
1694   appendRangeVar(oldDomain);
1695   // Swap new vars with domain.
1696   for (unsigned i = 0; i < oldDomain; ++i)
1697     swapVar(i, oldDomain + oldRange + i);
1698   // Remove the swapped domain.
1699   removeVarRange(0, oldDomain);
1700   // Set domain and range as inverse.
1701   numDomainDims = oldRange;
1702   numRangeDims = oldDomain;
1703 }
1704 
insertDomainVar(unsigned pos,unsigned num)1705 void FlatAffineRelation::insertDomainVar(unsigned pos, unsigned num) {
1706   assert(pos <= getNumDomainDims() &&
1707          "Var cannot be inserted at invalid position");
1708   insertDimVar(pos, num);
1709   numDomainDims += num;
1710 }
1711 
insertRangeVar(unsigned pos,unsigned num)1712 void FlatAffineRelation::insertRangeVar(unsigned pos, unsigned num) {
1713   assert(pos <= getNumRangeDims() &&
1714          "Var cannot be inserted at invalid position");
1715   insertDimVar(getNumDomainDims() + pos, num);
1716   numRangeDims += num;
1717 }
1718 
appendDomainVar(unsigned num)1719 void FlatAffineRelation::appendDomainVar(unsigned num) {
1720   insertDimVar(getNumDomainDims(), num);
1721   numDomainDims += num;
1722 }
1723 
appendRangeVar(unsigned num)1724 void FlatAffineRelation::appendRangeVar(unsigned num) {
1725   insertDimVar(getNumDimVars(), num);
1726   numRangeDims += num;
1727 }
1728 
removeVarRange(VarKind kind,unsigned varStart,unsigned varLimit)1729 void FlatAffineRelation::removeVarRange(VarKind kind, unsigned varStart,
1730                                         unsigned varLimit) {
1731   assert(varLimit <= getNumVarKind(kind));
1732   if (varStart >= varLimit)
1733     return;
1734 
1735   FlatAffineValueConstraints::removeVarRange(kind, varStart, varLimit);
1736 
1737   // If kind is not SetDim, domain and range don't need to be updated.
1738   if (kind != VarKind::SetDim)
1739     return;
1740 
1741   // Compute number of domain and range variables to remove. This is done by
1742   // intersecting the range of domain/range vars with range of vars to remove.
1743   unsigned intersectDomainLHS = std::min(varLimit, getNumDomainDims());
1744   unsigned intersectDomainRHS = varStart;
1745   unsigned intersectRangeLHS = std::min(varLimit, getNumDimVars());
1746   unsigned intersectRangeRHS = std::max(varStart, getNumDomainDims());
1747 
1748   if (intersectDomainLHS > intersectDomainRHS)
1749     numDomainDims -= intersectDomainLHS - intersectDomainRHS;
1750   if (intersectRangeLHS > intersectRangeRHS)
1751     numRangeDims -= intersectRangeLHS - intersectRangeRHS;
1752 }
1753 
getRelationFromMap(AffineMap & map,FlatAffineRelation & rel)1754 LogicalResult mlir::getRelationFromMap(AffineMap &map,
1755                                        FlatAffineRelation &rel) {
1756   // Get flattened affine expressions.
1757   std::vector<SmallVector<int64_t, 8>> flatExprs;
1758   FlatAffineValueConstraints localVarCst;
1759   if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst)))
1760     return failure();
1761 
1762   unsigned oldDimNum = localVarCst.getNumDimVars();
1763   unsigned oldCols = localVarCst.getNumCols();
1764   unsigned numRangeVars = map.getNumResults();
1765   unsigned numDomainVars = map.getNumDims();
1766 
1767   // Add range as the new expressions.
1768   localVarCst.appendDimVar(numRangeVars);
1769 
1770   // Add equalities between source and range.
1771   SmallVector<int64_t, 8> eq(localVarCst.getNumCols());
1772   for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1773     // Zero fill.
1774     std::fill(eq.begin(), eq.end(), 0);
1775     // Fill equality.
1776     for (unsigned j = 0, f = oldDimNum; j < f; ++j)
1777       eq[j] = flatExprs[i][j];
1778     for (unsigned j = oldDimNum, f = oldCols; j < f; ++j)
1779       eq[j + numRangeVars] = flatExprs[i][j];
1780     // Set this dimension to -1 to equate lhs and rhs and add equality.
1781     eq[numDomainVars + i] = -1;
1782     localVarCst.addEquality(eq);
1783   }
1784 
1785   // Create relation and return success.
1786   rel = FlatAffineRelation(numDomainVars, numRangeVars, localVarCst);
1787   return success();
1788 }
1789 
getRelationFromMap(const AffineValueMap & map,FlatAffineRelation & rel)1790 LogicalResult mlir::getRelationFromMap(const AffineValueMap &map,
1791                                        FlatAffineRelation &rel) {
1792 
1793   AffineMap affineMap = map.getAffineMap();
1794   if (failed(getRelationFromMap(affineMap, rel)))
1795     return failure();
1796 
1797   // Set symbol values for domain dimensions and symbols.
1798   for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
1799     rel.setValue(i, map.getOperand(i));
1800   for (unsigned i = rel.getNumDimVars(), e = rel.getNumDimAndSymbolVars();
1801        i < e; ++i)
1802     rel.setValue(i, map.getOperand(i - rel.getNumRangeDims()));
1803 
1804   return success();
1805 }
1806