1 //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Structures for affine/polyhedral analysis of affine dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
14 #include "mlir/Analysis/Presburger/LinearTransform.h"
15 #include "mlir/Analysis/Presburger/Simplex.h"
16 #include "mlir/Analysis/Presburger/Utils.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
20 #include "mlir/IR/AffineExprVisitor.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/Support/LLVM.h"
23 #include "mlir/Support/MathExtras.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 #define DEBUG_TYPE "affine-structures"
31 
32 using namespace mlir;
33 using namespace presburger;
34 
35 namespace {
36 
37 // See comments for SimpleAffineExprFlattener.
38 // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
39 // constraint information associated with mod's, floordiv's, and ceildiv's
40 // in FlatAffineValueConstraints 'localVarCst'.
41 struct AffineExprFlattener : public SimpleAffineExprFlattener {
42 public:
43   // Constraints connecting newly introduced local variables (for mod's and
44   // div's) to existing (dimensional and symbolic) ones. These are always
45   // inequalities.
46   IntegerPolyhedron localVarCst;
47 
48   AffineExprFlattener(unsigned nDims, unsigned nSymbols)
49       : SimpleAffineExprFlattener(nDims, nSymbols),
50         localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)) {}
51 
52 private:
53   // Add a local variable (needed to flatten a mod, floordiv, ceildiv expr).
54   // The local variable added is always a floordiv of a pure add/mul affine
55   // function of other variables, coefficients of which are specified in
56   // `dividend' and with respect to the positive constant `divisor'. localExpr
57   // is the simplified tree expression (AffineExpr) corresponding to the
58   // quantifier.
59   void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
60                           AffineExpr localExpr) override {
61     SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
62     // Update localVarCst.
63     localVarCst.addLocalFloorDiv(dividend, divisor);
64   }
65 };
66 
67 } // namespace
68 
69 // Flattens the expressions in map. Returns failure if 'expr' was unable to be
70 // flattened (i.e., semi-affine expressions not handled yet).
71 static LogicalResult
72 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
73                         unsigned numSymbols,
74                         std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
75                         FlatAffineValueConstraints *localVarCst) {
76   if (exprs.empty()) {
77     localVarCst->reset(numDims, numSymbols);
78     return success();
79   }
80 
81   AffineExprFlattener flattener(numDims, numSymbols);
82   // Use the same flattener to simplify each expression successively. This way
83   // local variables / expressions are shared.
84   for (auto expr : exprs) {
85     if (!expr.isPureAffine())
86       return failure();
87 
88     flattener.walkPostOrder(expr);
89   }
90 
91   assert(flattener.operandExprStack.size() == exprs.size());
92   flattenedExprs->clear();
93   flattenedExprs->assign(flattener.operandExprStack.begin(),
94                          flattener.operandExprStack.end());
95 
96   if (localVarCst)
97     localVarCst->clearAndCopyFrom(flattener.localVarCst);
98 
99   return success();
100 }
101 
102 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
103 // be flattened (semi-affine expressions not handled yet).
104 LogicalResult
105 mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
106                              unsigned numSymbols,
107                              SmallVectorImpl<int64_t> *flattenedExpr,
108                              FlatAffineValueConstraints *localVarCst) {
109   std::vector<SmallVector<int64_t, 8>> flattenedExprs;
110   LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
111                                                 &flattenedExprs, localVarCst);
112   *flattenedExpr = flattenedExprs[0];
113   return ret;
114 }
115 
116 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be
117 /// flattened (i.e., semi-affine expressions not handled yet).
118 LogicalResult mlir::getFlattenedAffineExprs(
119     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
120     FlatAffineValueConstraints *localVarCst) {
121   if (map.getNumResults() == 0) {
122     localVarCst->reset(map.getNumDims(), map.getNumSymbols());
123     return success();
124   }
125   return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
126                                    map.getNumSymbols(), flattenedExprs,
127                                    localVarCst);
128 }
129 
130 LogicalResult mlir::getFlattenedAffineExprs(
131     IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
132     FlatAffineValueConstraints *localVarCst) {
133   if (set.getNumConstraints() == 0) {
134     localVarCst->reset(set.getNumDims(), set.getNumSymbols());
135     return success();
136   }
137   return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
138                                    set.getNumSymbols(), flattenedExprs,
139                                    localVarCst);
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // FlatAffineConstraints / FlatAffineValueConstraints.
144 //===----------------------------------------------------------------------===//
145 
146 std::unique_ptr<FlatAffineValueConstraints>
147 FlatAffineValueConstraints::clone() const {
148   return std::make_unique<FlatAffineValueConstraints>(*this);
149 }
150 
151 // Construct from an IntegerSet.
152 FlatAffineValueConstraints::FlatAffineValueConstraints(IntegerSet set)
153     : IntegerPolyhedron(set.getNumInequalities(), set.getNumEqualities(),
154                         set.getNumDims() + set.getNumSymbols() + 1,
155                         PresburgerSpace::getSetSpace(set.getNumDims(),
156                                                      set.getNumSymbols(),
157                                                      /*numLocals=*/0)) {
158 
159   // Resize values.
160   values.resize(getNumDimAndSymbolVars(), None);
161 
162   // Flatten expressions and add them to the constraint system.
163   std::vector<SmallVector<int64_t, 8>> flatExprs;
164   FlatAffineValueConstraints localVarCst;
165   if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
166     assert(false && "flattening unimplemented for semi-affine integer sets");
167     return;
168   }
169   assert(flatExprs.size() == set.getNumConstraints());
170   insertVar(VarKind::Local, getNumVarKind(VarKind::Local),
171             /*num=*/localVarCst.getNumLocalVars());
172 
173   for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
174     const auto &flatExpr = flatExprs[i];
175     assert(flatExpr.size() == getNumCols());
176     if (set.getEqFlags()[i]) {
177       addEquality(flatExpr);
178     } else {
179       addInequality(flatExpr);
180     }
181   }
182   // Add the other constraints involving local vars from flattening.
183   append(localVarCst);
184 }
185 
186 // Construct a hyperrectangular constraint set from ValueRanges that represent
187 // induction variables, lower and upper bounds. `ivs`, `lbs` and `ubs` are
188 // expected to match one to one. The order of variables and constraints is:
189 //
190 // ivs | lbs | ubs | eq/ineq
191 // ----+-----+-----+---------
192 //   1   -1     0      >= 0
193 // ----+-----+-----+---------
194 //  -1    0     1      >= 0
195 //
196 // All dimensions as set as VarKind::SetDim.
197 FlatAffineValueConstraints
198 FlatAffineValueConstraints::getHyperrectangular(ValueRange ivs, ValueRange lbs,
199                                                 ValueRange ubs) {
200   FlatAffineValueConstraints res;
201   unsigned nIvs = ivs.size();
202   assert(nIvs == lbs.size() && "expected as many lower bounds as ivs");
203   assert(nIvs == ubs.size() && "expected as many upper bounds as ivs");
204 
205   if (nIvs == 0)
206     return res;
207 
208   res.appendDimVar(ivs);
209   unsigned lbsStart = res.appendDimVar(lbs);
210   unsigned ubsStart = res.appendDimVar(ubs);
211 
212   MLIRContext *ctx = ivs.front().getContext();
213   for (int ivIdx = 0, e = nIvs; ivIdx < e; ++ivIdx) {
214     // iv - lb >= 0
215     AffineMap lb = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
216                                   getAffineDimExpr(lbsStart + ivIdx, ctx));
217     if (failed(res.addBound(BoundType::LB, ivIdx, lb)))
218       llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
219     // -iv + ub >= 0
220     AffineMap ub = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
221                                   getAffineDimExpr(ubsStart + ivIdx, ctx));
222     if (failed(res.addBound(BoundType::UB, ivIdx, ub)))
223       llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
224   }
225   return res;
226 }
227 
228 void FlatAffineValueConstraints::reset(unsigned numReservedInequalities,
229                                        unsigned numReservedEqualities,
230                                        unsigned newNumReservedCols,
231                                        unsigned newNumDims,
232                                        unsigned newNumSymbols,
233                                        unsigned newNumLocals) {
234   assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
235          "minimum 1 column");
236   *this = FlatAffineValueConstraints(numReservedInequalities,
237                                      numReservedEqualities, newNumReservedCols,
238                                      newNumDims, newNumSymbols, newNumLocals);
239 }
240 
241 void FlatAffineValueConstraints::reset(unsigned newNumDims,
242                                        unsigned newNumSymbols,
243                                        unsigned newNumLocals) {
244   reset(/*numReservedInequalities=*/0, /*numReservedEqualities=*/0,
245         /*numReservedCols=*/newNumDims + newNumSymbols + newNumLocals + 1,
246         newNumDims, newNumSymbols, newNumLocals);
247 }
248 
249 void FlatAffineValueConstraints::reset(
250     unsigned numReservedInequalities, unsigned numReservedEqualities,
251     unsigned newNumReservedCols, unsigned newNumDims, unsigned newNumSymbols,
252     unsigned newNumLocals, ArrayRef<Value> valArgs) {
253   assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
254          "minimum 1 column");
255   SmallVector<Optional<Value>, 8> newVals;
256   if (!valArgs.empty())
257     newVals.assign(valArgs.begin(), valArgs.end());
258 
259   *this = FlatAffineValueConstraints(
260       numReservedInequalities, numReservedEqualities, newNumReservedCols,
261       newNumDims, newNumSymbols, newNumLocals, newVals);
262 }
263 
264 void FlatAffineValueConstraints::reset(unsigned newNumDims,
265                                        unsigned newNumSymbols,
266                                        unsigned newNumLocals,
267                                        ArrayRef<Value> valArgs) {
268   reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
269         newNumSymbols, newNumLocals, valArgs);
270 }
271 
272 unsigned FlatAffineValueConstraints::appendDimVar(ValueRange vals) {
273   unsigned pos = getNumDimVars();
274   insertVar(VarKind::SetDim, pos, vals);
275   return pos;
276 }
277 
278 unsigned FlatAffineValueConstraints::appendSymbolVar(ValueRange vals) {
279   unsigned pos = getNumSymbolVars();
280   insertVar(VarKind::Symbol, pos, vals);
281   return pos;
282 }
283 
284 unsigned FlatAffineValueConstraints::insertDimVar(unsigned pos,
285                                                   ValueRange vals) {
286   return insertVar(VarKind::SetDim, pos, vals);
287 }
288 
289 unsigned FlatAffineValueConstraints::insertSymbolVar(unsigned pos,
290                                                      ValueRange vals) {
291   return insertVar(VarKind::Symbol, pos, vals);
292 }
293 
294 unsigned FlatAffineValueConstraints::insertVar(VarKind kind, unsigned pos,
295                                                unsigned num) {
296   unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
297 
298   if (kind != VarKind::Local) {
299     values.insert(values.begin() + absolutePos, num, None);
300     assert(values.size() == getNumDimAndSymbolVars());
301   }
302 
303   return absolutePos;
304 }
305 
306 unsigned FlatAffineValueConstraints::insertVar(VarKind kind, unsigned pos,
307                                                ValueRange vals) {
308   assert(!vals.empty() && "expected ValueRange with Values.");
309   assert(kind != VarKind::Local &&
310          "values cannot be attached to local variables.");
311   unsigned num = vals.size();
312   unsigned absolutePos = IntegerPolyhedron::insertVar(kind, pos, num);
313 
314   // If a Value is provided, insert it; otherwise use None.
315   for (unsigned i = 0; i < num; ++i)
316     values.insert(values.begin() + absolutePos + i,
317                   vals[i] ? Optional<Value>(vals[i]) : None);
318 
319   assert(values.size() == getNumDimAndSymbolVars());
320   return absolutePos;
321 }
322 
323 bool FlatAffineValueConstraints::hasValues() const {
324   return llvm::find_if(values, [](Optional<Value> var) {
325            return var.hasValue();
326          }) != values.end();
327 }
328 
329 /// Checks if two constraint systems are in the same space, i.e., if they are
330 /// associated with the same set of variables, appearing in the same order.
331 static bool areVarsAligned(const FlatAffineValueConstraints &a,
332                            const FlatAffineValueConstraints &b) {
333   return a.getNumDimVars() == b.getNumDimVars() &&
334          a.getNumSymbolVars() == b.getNumSymbolVars() &&
335          a.getNumVars() == b.getNumVars() &&
336          a.getMaybeValues().equals(b.getMaybeValues());
337 }
338 
339 /// Calls areVarsAligned to check if two constraint systems have the same set
340 /// of variables in the same order.
341 bool FlatAffineValueConstraints::areVarsAlignedWithOther(
342     const FlatAffineValueConstraints &other) {
343   return areVarsAligned(*this, other);
344 }
345 
346 /// Checks if the SSA values associated with `cst`'s variables in range
347 /// [start, end) are unique.
348 static bool LLVM_ATTRIBUTE_UNUSED areVarsUnique(
349     const FlatAffineValueConstraints &cst, unsigned start, unsigned end) {
350 
351   assert(start <= cst.getNumDimAndSymbolVars() &&
352          "Start position out of bounds");
353   assert(end <= cst.getNumDimAndSymbolVars() && "End position out of bounds");
354 
355   if (start >= end)
356     return true;
357 
358   SmallPtrSet<Value, 8> uniqueVars;
359   ArrayRef<Optional<Value>> maybeValues =
360       cst.getMaybeValues().slice(start, end - start);
361   for (Optional<Value> val : maybeValues) {
362     if (val && !uniqueVars.insert(*val).second)
363       return false;
364   }
365   return true;
366 }
367 
368 /// Checks if the SSA values associated with `cst`'s variables are unique.
369 static bool LLVM_ATTRIBUTE_UNUSED
370 areVarsUnique(const FlatAffineValueConstraints &cst) {
371   return areVarsUnique(cst, 0, cst.getNumDimAndSymbolVars());
372 }
373 
374 /// Checks if the SSA values associated with `cst`'s variables of kind `kind`
375 /// are unique.
376 static bool LLVM_ATTRIBUTE_UNUSED
377 areVarsUnique(const FlatAffineValueConstraints &cst, VarKind kind) {
378 
379   if (kind == VarKind::SetDim)
380     return areVarsUnique(cst, 0, cst.getNumDimVars());
381   if (kind == VarKind::Symbol)
382     return areVarsUnique(cst, cst.getNumDimVars(),
383                          cst.getNumDimAndSymbolVars());
384   llvm_unreachable("Unexpected VarKind");
385 }
386 
/// Merge and align the variables of A and B starting at 'offset', so that
/// both constraint systems get the union of the contained variables that is
/// dimension-wise and symbol-wise unique; both constraint systems are updated
/// so that they have the union of all variables, with A's original
/// variables appearing first followed by any of B's variables that didn't
/// appear in A. Local variables in B that have the same division
/// representation as local variables in A are merged into one.
///
/// E.g.: Input:  A has (%i, %j) [%M, %N] and B has (%k, %j) [%P, %N, %M]
///       Output: both A and B have (%i, %j, %k) [%M, %N, %P]
///
/// Dims are aligned here; symbols and locals are delegated to
/// `mergeSymbolVars` and `mergeLocalVars` respectively.
static void mergeAndAlignVars(unsigned offset, FlatAffineValueConstraints *a,
                              FlatAffineValueConstraints *b) {
  assert(offset <= a->getNumDimVars() && offset <= b->getNumDimVars());
  // A merge/align isn't meaningful if a cst's vars aren't distinct.
  assert(areVarsUnique(*a) && "A's values aren't unique");
  assert(areVarsUnique(*b) && "B's values aren't unique");

  // Every dim/symbol at or after `offset` must have an SSA value attached in
  // both systems, since alignment is done by value identity.
  assert(std::all_of(a->getMaybeValues().begin() + offset,
                     a->getMaybeValues().end(),
                     [](Optional<Value> var) { return var.hasValue(); }));

  assert(std::all_of(b->getMaybeValues().begin() + offset,
                     b->getMaybeValues().end(),
                     [](Optional<Value> var) { return var.hasValue(); }));

  // Snapshot A's dims in [offset, numDims) before B is reshuffled.
  SmallVector<Value, 4> aDimValues;
  a->getValues(offset, a->getNumDimVars(), &aDimValues);

  {
    // Merge dims from A into B: for each of A's dims, in order, either move
    // the matching dim of B into position `d` or insert a new dim there.
    unsigned d = offset;
    for (auto aDimValue : aDimValues) {
      unsigned loc;
      if (b->findVar(aDimValue, &loc)) {
        assert(loc >= offset && "A's dim appears in B's aligned range");
        assert(loc < b->getNumDimVars() &&
               "A's dim appears in B's non-dim position");
        b->swapVar(d, loc);
      } else {
        b->insertDimVar(d, aDimValue);
      }
      d++;
    }
    // Dimensions that are in B, but not in A, are added at the end. The
    // bounds are captured once, so A's growth does not affect the loop.
    for (unsigned t = a->getNumDimVars(), e = b->getNumDimVars(); t < e; t++) {
      a->appendDimVar(b->getValue(t));
    }
    assert(a->getNumDimVars() == b->getNumDimVars() &&
           "expected same number of dims");
  }

  // Merge and align symbols of A and B
  a->mergeSymbolVars(*b);
  // Merge and align locals of A and B
  a->mergeLocalVars(*b);

  assert(areVarsAligned(*a, *b) && "IDs expected to be aligned");
}
444 
445 // Call 'mergeAndAlignVars' to align constraint systems of 'this' and 'other'.
446 void FlatAffineValueConstraints::mergeAndAlignVarsWithOther(
447     unsigned offset, FlatAffineValueConstraints *other) {
448   mergeAndAlignVars(offset, this, other);
449 }
450 
451 LogicalResult
452 FlatAffineValueConstraints::composeMap(const AffineValueMap *vMap) {
453   return composeMatchingMap(
454       computeAlignedMap(vMap->getAffineMap(), vMap->getOperands()));
455 }
456 
457 // Similar to `composeMap` except that no Values need be associated with the
458 // constraint system nor are they looked at -- the dimensions and symbols of
459 // `other` are expected to correspond 1:1 to `this` system.
460 LogicalResult FlatAffineValueConstraints::composeMatchingMap(AffineMap other) {
461   assert(other.getNumDims() == getNumDimVars() && "dim mismatch");
462   assert(other.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
463 
464   std::vector<SmallVector<int64_t, 8>> flatExprs;
465   if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
466     return failure();
467   assert(flatExprs.size() == other.getNumResults());
468 
469   // Add dimensions corresponding to the map's results.
470   insertDimVar(/*pos=*/0, /*num=*/other.getNumResults());
471 
472   // We add one equality for each result connecting the result dim of the map to
473   // the other variables.
474   // E.g.: if the expression is 16*i0 + i1, and this is the r^th
475   // iteration/result of the value map, we are adding the equality:
476   // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
477   // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
478   for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
479     const auto &flatExpr = flatExprs[r];
480     assert(flatExpr.size() >= other.getNumInputs() + 1);
481 
482     SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
483     // Set the coefficient for this result to one.
484     eqToAdd[r] = 1;
485 
486     // Dims and symbols.
487     for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
488       // Negate `eq[r]` since the newly added dimension will be set to this one.
489       eqToAdd[e + i] = -flatExpr[i];
490     }
491     // Local columns of `eq` are at the beginning.
492     unsigned j = getNumDimVars() + getNumSymbolVars();
493     unsigned end = flatExpr.size() - 1;
494     for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
495       eqToAdd[j] = -flatExpr[i];
496     }
497 
498     // Constant term.
499     eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];
500 
501     // Add the equality connecting the result of the map to this constraint set.
502     addEquality(eqToAdd);
503   }
504 
505   return success();
506 }
507 
508 // Turn a symbol into a dimension.
509 static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value value) {
510   unsigned pos;
511   if (cst->findVar(value, &pos) && pos >= cst->getNumDimVars() &&
512       pos < cst->getNumDimAndSymbolVars()) {
513     cst->swapVar(pos, cst->getNumDimVars());
514     cst->setDimSymbolSeparation(cst->getNumSymbolVars() - 1);
515   }
516 }
517 
/// Merge and align symbols of `this` and `other` such that both get the union
/// of their symbols, without duplicates. Symbols in `this` and `other` are
/// expected to be unique. Symbols with Value as `None` are considered to be
/// unequal to all other symbols.
/// NOTE(review): `getValues` is called on all of `this`'s symbols below;
/// confirm it tolerates symbols that have no Value attached.
///
/// Afterwards both systems list `this`'s original symbols first (in order),
/// followed by the symbols that appeared only in `other`.
void FlatAffineValueConstraints::mergeSymbolVars(
    FlatAffineValueConstraints &other) {

  assert(areVarsUnique(*this, VarKind::Symbol) && "Symbol vars are not unique");
  assert(areVarsUnique(other, VarKind::Symbol) && "Symbol vars are not unique");

  // Snapshot `this`'s symbol values; `other` is reshuffled to match them.
  SmallVector<Value, 4> aSymValues;
  getValues(getNumDimVars(), getNumDimAndSymbolVars(), &aSymValues);

  // Merge symbols: merge symbols into `other` first from `this`.
  // `s` is the absolute position in `other` where the next symbol must land.
  unsigned s = other.getNumDimVars();
  for (Value aSymValue : aSymValues) {
    unsigned loc;
    // If the var is a symbol in `other`, then align it, otherwise assume that
    // it is a new symbol. (A match in `other`'s dims does not count.)
    if (other.findVar(aSymValue, &loc) && loc >= other.getNumDimVars() &&
        loc < other.getNumDimAndSymbolVars())
      other.swapVar(s, loc);
    else
      other.insertSymbolVar(s - other.getNumDimVars(), aSymValue);
    s++;
  }

  // Symbols that are in other, but not in this, are added at the end. They
  // occupy `other`'s symbol positions past the first getNumSymbolVars().
  for (unsigned t = other.getNumDimVars() + getNumSymbolVars(),
                e = other.getNumDimAndSymbolVars();
       t < e; t++)
    insertSymbolVar(getNumSymbolVars(), other.getValue(t));

  assert(getNumSymbolVars() == other.getNumSymbolVars() &&
         "expected same number of symbols");
  assert(areVarsUnique(*this, VarKind::Symbol) && "Symbol vars are not unique");
  assert(areVarsUnique(other, VarKind::Symbol) && "Symbol vars are not unique");
}
556 
557 // Changes all symbol variables which are loop IVs to dim variables.
558 void FlatAffineValueConstraints::convertLoopIVSymbolsToDims() {
559   // Gather all symbols which are loop IVs.
560   SmallVector<Value, 4> loopIVs;
561   for (unsigned i = getNumDimVars(), e = getNumDimAndSymbolVars(); i < e; i++) {
562     if (hasValue(i) && getForInductionVarOwner(getValue(i)))
563       loopIVs.push_back(getValue(i));
564   }
565   // Turn each symbol in 'loopIVs' into a dim variable.
566   for (auto iv : loopIVs) {
567     turnSymbolIntoDim(this, iv);
568   }
569 }
570 
571 void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) {
572   if (containsVar(val))
573     return;
574 
575   // Caller is expected to fully compose map/operands if necessary.
576   assert((isTopLevelValue(val) || isForInductionVar(val)) &&
577          "non-terminal symbol / loop IV expected");
578   // Outer loop IVs could be used in forOp's bounds.
579   if (auto loop = getForInductionVarOwner(val)) {
580     appendDimVar(val);
581     if (failed(this->addAffineForOpDomain(loop)))
582       LLVM_DEBUG(
583           loop.emitWarning("failed to add domain info to constraint system"));
584     return;
585   }
586   // Add top level symbol.
587   appendSymbolVar(val);
588   // Check if the symbol is a constant.
589   if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>())
590     addBound(BoundType::EQ, val, constOp.value());
591 }
592 
593 LogicalResult
594 FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) {
595   unsigned pos;
596   // Pre-condition for this method.
597   if (!findVar(forOp.getInductionVar(), &pos)) {
598     assert(false && "Value not found");
599     return failure();
600   }
601 
602   int64_t step = forOp.getStep();
603   if (step != 1) {
604     if (!forOp.hasConstantLowerBound())
605       LLVM_DEBUG(forOp.emitWarning("domain conservatively approximated"));
606     else {
607       // Add constraints for the stride.
608       // (iv - lb) % step = 0 can be written as:
609       // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
610       // Add local variable 'q' and add the above equality.
611       // The first constraint is q = (iv - lb) floordiv step
612       SmallVector<int64_t, 8> dividend(getNumCols(), 0);
613       int64_t lb = forOp.getConstantLowerBound();
614       dividend[pos] = 1;
615       dividend.back() -= lb;
616       addLocalFloorDiv(dividend, step);
617       // Second constraint: (iv - lb) - step * q = 0.
618       SmallVector<int64_t, 8> eq(getNumCols(), 0);
619       eq[pos] = 1;
620       eq.back() -= lb;
621       // For the local var just added above.
622       eq[getNumCols() - 2] = -step;
623       addEquality(eq);
624     }
625   }
626 
627   if (forOp.hasConstantLowerBound()) {
628     addBound(BoundType::LB, pos, forOp.getConstantLowerBound());
629   } else {
630     // Non-constant lower bound case.
631     if (failed(addBound(BoundType::LB, pos, forOp.getLowerBoundMap(),
632                         forOp.getLowerBoundOperands())))
633       return failure();
634   }
635 
636   if (forOp.hasConstantUpperBound()) {
637     addBound(BoundType::UB, pos, forOp.getConstantUpperBound() - 1);
638     return success();
639   }
640   // Non-constant upper bound case.
641   return addBound(BoundType::UB, pos, forOp.getUpperBoundMap(),
642                   forOp.getUpperBoundOperands());
643 }
644 
645 LogicalResult
646 FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
647                                                    ArrayRef<AffineMap> ubMaps,
648                                                    ArrayRef<Value> operands) {
649   assert(lbMaps.size() == ubMaps.size());
650   assert(lbMaps.size() <= getNumDimVars());
651 
652   for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
653     AffineMap lbMap = lbMaps[i];
654     AffineMap ubMap = ubMaps[i];
655     assert(!lbMap || lbMap.getNumInputs() == operands.size());
656     assert(!ubMap || ubMap.getNumInputs() == operands.size());
657 
658     // Check if this slice is just an equality along this dimension. If so,
659     // retrieve the existing loop it equates to and add it to the system.
660     if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
661         ubMap.getNumResults() == 1 &&
662         lbMap.getResult(0) + 1 == ubMap.getResult(0) &&
663         // The condition above will be true for maps describing a single
664         // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
665         // Make sure we skip those cases by checking that the lb result is not
666         // just a constant.
667         !lbMap.getResult(0).isa<AffineConstantExpr>()) {
668       // Limited support: we expect the lb result to be just a loop dimension.
669       // Not supported otherwise for now.
670       AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
671       if (!result)
672         return failure();
673 
674       AffineForOp loop =
675           getForInductionVarOwner(operands[result.getPosition()]);
676       if (!loop)
677         return failure();
678 
679       if (failed(addAffineForOpDomain(loop)))
680         return failure();
681       continue;
682     }
683 
684     // This slice refers to a loop that doesn't exist in the IR yet. Add its
685     // bounds to the system assuming its dimension variable position is the
686     // same as the position of the loop in the loop nest.
687     if (lbMap && failed(addBound(BoundType::LB, i, lbMap, operands)))
688       return failure();
689     if (ubMap && failed(addBound(BoundType::UB, i, ubMap, operands)))
690       return failure();
691   }
692   return success();
693 }
694 
695 void FlatAffineValueConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
696   // Create the base constraints from the integer set attached to ifOp.
697   FlatAffineValueConstraints cst(ifOp.getIntegerSet());
698 
699   // Bind vars in the constraints to ifOp operands.
700   SmallVector<Value, 4> operands = ifOp.getOperands();
701   cst.setValues(0, cst.getNumDimAndSymbolVars(), operands);
702 
703   // Merge the constraints from ifOp to the current domain. We need first merge
704   // and align the IDs from both constraints, and then append the constraints
705   // from the ifOp into the current one.
706   mergeAndAlignVarsWithOther(0, &cst);
707   append(cst);
708 }
709 
710 bool FlatAffineValueConstraints::hasConsistentState() const {
711   return IntegerPolyhedron::hasConsistentState() &&
712          values.size() == getNumDimAndSymbolVars();
713 }
714 
715 void FlatAffineValueConstraints::removeVarRange(VarKind kind, unsigned varStart,
716                                                 unsigned varLimit) {
717   IntegerPolyhedron::removeVarRange(kind, varStart, varLimit);
718   unsigned offset = getVarKindOffset(kind);
719 
720   if (kind != VarKind::Local) {
721     values.erase(values.begin() + varStart + offset,
722                  values.begin() + varLimit + offset);
723   }
724 }
725 
726 // Determine whether the variable at 'pos' (say var_r) can be expressed as
727 // modulo of another known variable (say var_n) w.r.t a constant. For example,
728 // if the following constraints hold true:
729 // ```
730 // 0 <= var_r <= divisor - 1
731 // var_n - (divisor * q_expr) = var_r
732 // ```
733 // where `var_n` is a known variable (called dividend), and `q_expr` is an
734 // `AffineExpr` (called the quotient expression), `var_r` can be written as:
735 //
736 // `var_r = var_n mod divisor`.
737 //
// Additionally, in a special case of the above constraints where `q_expr` is a
// variable itself that is not yet known (say `var_q`), it can be written as a
740 // floordiv in the following way:
741 //
742 // `var_q = var_n floordiv divisor`.
743 //
744 // Returns true if the above mod or floordiv are detected, updating 'memo' with
745 // these new expressions. Returns false otherwise.
746 static bool detectAsMod(const FlatAffineValueConstraints &cst, unsigned pos,
747                         int64_t lbConst, int64_t ubConst,
748                         SmallVectorImpl<AffineExpr> &memo,
749                         MLIRContext *context) {
750   assert(pos < cst.getNumVars() && "invalid position");
751 
752   // Check if a divisor satisfying the condition `0 <= var_r <= divisor - 1` can
753   // be determined.
754   if (lbConst != 0 || ubConst < 1)
755     return false;
756   int64_t divisor = ubConst + 1;
757 
758   // Check for the aforementioned conditions in each equality.
759   for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
760        curEquality < numEqualities; curEquality++) {
761     int64_t coefficientAtPos = cst.atEq(curEquality, pos);
762     // If current equality does not involve `var_r`, continue to the next
763     // equality.
764     if (coefficientAtPos == 0)
765       continue;
766 
767     // Constant term should be 0 in this equality.
768     if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0)
769       continue;
770 
771     // Traverse through the equality and construct the dividend expression
772     // `dividendExpr`, to contain all the variables which are known and are
773     // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the
774     // `dividendExpr` gets simplified into a single variable `var_n` discussed
775     // above.
776     auto dividendExpr = getAffineConstantExpr(0, context);
777 
778     // Track the terms that go into quotient expression, later used to detect
779     // additional floordiv.
780     unsigned quotientCount = 0;
781     int quotientPosition = -1;
782     int quotientSign = 1;
783 
784     // Consider each term in the current equality.
785     unsigned curVar, e;
786     for (curVar = 0, e = cst.getNumDimAndSymbolVars(); curVar < e; ++curVar) {
787       // Ignore var_r.
788       if (curVar == pos)
789         continue;
790       int64_t coefficientOfCurVar = cst.atEq(curEquality, curVar);
791       // Ignore vars that do not contribute to the current equality.
792       if (coefficientOfCurVar == 0)
793         continue;
794       // Check if the current var goes into the quotient expression.
795       if (coefficientOfCurVar % (divisor * coefficientAtPos) == 0) {
796         quotientCount++;
797         quotientPosition = curVar;
798         quotientSign = (coefficientOfCurVar * coefficientAtPos) > 0 ? 1 : -1;
799         continue;
800       }
801       // Variables that are part of dividendExpr should be known.
802       if (!memo[curVar])
803         break;
804       // Append the current variable to the dividend expression.
805       dividendExpr = dividendExpr + memo[curVar] * coefficientOfCurVar;
806     }
807 
808     // Can't construct expression as it depends on a yet uncomputed var.
809     if (curVar < e)
810       continue;
811 
812     // Express `var_r` in terms of the other vars collected so far.
813     if (coefficientAtPos > 0)
814       dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
815     else
816       dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);
817 
818     // Simplify the expression.
819     dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimVars(),
820                                       cst.getNumSymbolVars());
821     // Only if the final dividend expression is just a single var (which we call
822     // `var_n`), we can proceed.
823     // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
824     // to dims themselves.
825     auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
826     if (!dimExpr)
827       continue;
828 
829     // Express `var_r` as `var_n % divisor` and store the expression in `memo`.
830     if (quotientCount >= 1) {
831       auto ub = cst.getConstantBound(FlatAffineValueConstraints::BoundType::UB,
832                                      dimExpr.getPosition());
833       // If `var_n` has an upperbound that is less than the divisor, mod can be
834       // eliminated altogether.
835       if (ub && *ub < divisor)
836         memo[pos] = dimExpr;
837       else
838         memo[pos] = dimExpr % divisor;
839       // If a unique quotient `var_q` was seen, it can be expressed as
840       // `var_n floordiv divisor`.
841       if (quotientCount == 1 && !memo[quotientPosition])
842         memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign;
843 
844       return true;
845     }
846   }
847   return false;
848 }
849 
850 /// Check if the pos^th variable can be expressed as a floordiv of an affine
851 /// function of other variables (where the divisor is a positive constant)
852 /// given the initial set of expressions in `exprs`. If it can be, the
853 /// corresponding position in `exprs` is set as the detected affine expr. For
854 /// eg: 4q <= i + j <= 4q + 3   <=>   q = (i + j) floordiv 4. An equality can
855 /// also yield a floordiv: eg.  4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
856 /// <= i <= 32q + 31 => q = i floordiv 32.
857 static bool detectAsFloorDiv(const FlatAffineValueConstraints &cst,
858                              unsigned pos, MLIRContext *context,
859                              SmallVectorImpl<AffineExpr> &exprs) {
860   assert(pos < cst.getNumVars() && "invalid position");
861 
862   // Get upper-lower bound pair for this variable.
863   SmallVector<bool, 8> foundRepr(cst.getNumVars(), false);
864   for (unsigned i = 0, e = cst.getNumVars(); i < e; ++i)
865     if (exprs[i])
866       foundRepr[i] = true;
867 
868   SmallVector<int64_t, 8> dividend(cst.getNumCols());
869   unsigned divisor;
870   auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor);
871 
872   // No upper-lower bound pair found for this var.
873   if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality)
874     return false;
875 
876   // Construct the dividend expression.
877   auto dividendExpr = getAffineConstantExpr(dividend.back(), context);
878   for (unsigned c = 0, f = cst.getNumVars(); c < f; c++)
879     if (dividend[c] != 0)
880       dividendExpr = dividendExpr + dividend[c] * exprs[c];
881 
882   // Successfully detected the floordiv.
883   exprs[pos] = dividendExpr.floorDiv(divisor);
884   return true;
885 }
886 
/// Computes the lower and upper bound maps of the dimensional variable at
/// column `pos + offset`, expressed in terms of the columns outside the range
/// [offset, offset + num). The maps have `symStartPos - num` dims and
/// `getNumDimAndSymbolVars() - symStartPos` symbols; local columns are rendered
/// through `localExprs` (one expression per local variable). Upper bounds are
/// exclusive: every upper bound expression is incremented by one. Each
/// equality involving the variable contributes both a lower and an upper
/// bound. The constraint indices come from getLowerAndUpperBoundIndices(),
/// presumably filtered with respect to the other variables in
/// [offset, offset + num) -- NOTE(review): confirm against that helper.
/// Returns {lbMap, ubMap}.
std::pair<AffineMap, AffineMap>
FlatAffineValueConstraints::getLowerAndUpperBound(
    unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
    ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
  assert(pos + offset < getNumDimVars() && "invalid dim start pos");
  assert(symStartPos >= (pos + offset) && "invalid sym start pos");
  assert(getNumLocalVars() == localExprs.size() &&
         "incorrect local exprs count");

  SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
  getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
                               offset, num);

  /// Replace 'b' with every coefficient of 'a' whose column lies outside
  /// [offset, offset + num), i.e. the remaining dims, symbols, locals and the
  /// constant term, in their original order.
  auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
    b.clear();
    for (unsigned i = 0, e = a.size(); i < e; ++i) {
      if (i < offset || i >= offset + num)
        b.push_back(a[i]);
    }
  };

  SmallVector<int64_t, 8> lb, ub;
  SmallVector<AffineExpr, 4> lbExprs;
  // The `num` removed columns all lie before `symStartPos`, so the map keeps
  // `symStartPos - num` dims; everything from `symStartPos` on is a symbol.
  unsigned dimCount = symStartPos - num;
  unsigned symCount = getNumDimAndSymbolVars() - symStartPos;
  lbExprs.reserve(lbIndices.size() + eqIndices.size());
  // Lower bound expressions.
  for (auto idx : lbIndices) {
    auto ineq = getInequality(idx);
    // Extract the lower bound (in terms of other coeff's + const), i.e., if
    // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
    // - 1.
    addCoeffs(ineq, lb);
    std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
    auto expr =
        getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
    // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
    int64_t divisor = std::abs(ineq[pos + offset]);
    expr = (expr + divisor - 1).floorDiv(divisor);
    lbExprs.push_back(expr);
  }

  SmallVector<AffineExpr, 4> ubExprs;
  ubExprs.reserve(ubIndices.size() + eqIndices.size());
  // Upper bound expressions.
  for (auto idx : ubIndices) {
    auto ineq = getInequality(idx);
    // Extract the upper bound (in terms of other coeff's + const). The
    // remaining coefficients are used as-is (no negation), since the
    // variable's own coefficient is divided out via its absolute value below.
    addCoeffs(ineq, ub);
    auto expr =
        getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
    expr = expr.floorDiv(std::abs(ineq[pos + offset]));
    // Upper bound is exclusive.
    ubExprs.push_back(expr + 1);
  }

  // Equalities. It's both a lower and a upper bound.
  SmallVector<int64_t, 4> b;
  for (auto idx : eqIndices) {
    auto eq = getEquality(idx);
    addCoeffs(eq, b);
    // Normalize so that the remaining terms describe the variable scaled by
    // the absolute value of its coefficient.
    if (eq[pos + offset] > 0)
      std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());

    // Extract the upper bound (in terms of other coeff's + const).
    auto expr =
        getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
    expr = expr.floorDiv(std::abs(eq[pos + offset]));
    // Upper bound is exclusive.
    ubExprs.push_back(expr + 1);
    // Lower bound.
    expr =
        getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
    expr = expr.ceilDiv(std::abs(eq[pos + offset]));
    lbExprs.push_back(expr);
  }

  auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
  auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);

  return {lbMap, ubMap};
}
970 
/// Computes the lower and upper bounds of the first 'num' dimensional
/// variables (starting at 'offset') as affine maps of the remaining
/// variables (dimensional and symbolic variables). Local variables are
/// themselves explicitly computed as affine functions of other variables in
/// this process if needed.
///
/// The maps for variable `offset + pos` are written to `(*lbMaps)[pos]` and
/// `(*ubMaps)[pos]` for every pos in [0, num); both vectors are indexed
/// directly, so they must already hold at least `num` entries. The maps have
/// `getNumDimVars() - num` dims and `getNumSymbolVars()` symbols. Upper bounds
/// are exclusive unless `getClosedUB` is set. If no bound can be derived for a
/// variable, the corresponding entry is left as the caller provided it.
/// Note: this normalizes the constraints of `*this` by their GCD first.
void FlatAffineValueConstraints::getSliceBounds(
    unsigned offset, unsigned num, MLIRContext *context,
    SmallVectorImpl<AffineMap> *lbMaps, SmallVectorImpl<AffineMap> *ubMaps,
    bool getClosedUB) {
  assert(num < getNumDimVars() && "invalid range");

  // Basic simplification.
  normalizeConstraintsByGCD();

  LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
                          << " variables\n");
  LLVM_DEBUG(dump());

  // Record computed/detected variables.
  SmallVector<AffineExpr, 8> memo(getNumVars());
  // Initialize dimensional and symbolic variables. The `num` variables being
  // bounded get no dim of their own, so dims after them are shifted down by
  // `num` in the map's dim space.
  for (unsigned i = 0, e = getNumDimVars(); i < e; i++) {
    if (i < offset)
      memo[i] = getAffineDimExpr(i, context);
    else if (i >= offset + num)
      memo[i] = getAffineDimExpr(i - num, context);
  }
  for (unsigned i = getNumDimVars(), e = getNumDimAndSymbolVars(); i < e; i++)
    memo[i] = getAffineSymbolExpr(i - getNumDimVars(), context);

  bool changed;
  do {
    changed = false;
    // Identify yet unknown variables as constants or mod's / floordiv's of
    // other variables if possible.
    for (unsigned pos = 0; pos < getNumVars(); pos++) {
      if (memo[pos])
        continue;

      auto lbConst = getConstantBound(BoundType::LB, pos);
      auto ubConst = getConstantBound(BoundType::UB, pos);
      if (lbConst.hasValue() && ubConst.hasValue()) {
        // Detect equality to a constant.
        if (lbConst.getValue() == ubConst.getValue()) {
          memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
          changed = true;
          continue;
        }

        // Detect a variable as modulo of another variable w.r.t a
        // constant.
        if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
                        memo, context)) {
          changed = true;
          continue;
        }
      }

      // Detect a variable as a floordiv of an affine function of other
      // variables (divisor is a positive constant).
      if (detectAsFloorDiv(*this, pos, context, memo)) {
        changed = true;
        continue;
      }

      // Detect a variable as an expression of other variables.
      unsigned idx;
      if (!findConstraintWithNonZeroAt(pos, /*isEq=*/true, &idx)) {
        continue;
      }

      // Build AffineExpr solving for variable 'pos' in terms of all others.
      auto expr = getAffineConstantExpr(0, context);
      unsigned j, e;
      for (j = 0, e = getNumVars(); j < e; ++j) {
        if (j == pos)
          continue;
        int64_t c = atEq(idx, j);
        if (c == 0)
          continue;
        // If any of the involved IDs hasn't been found yet, we can't proceed.
        if (!memo[j])
          break;
        expr = expr + memo[j] * c;
      }
      if (j < e)
        // Can't construct expression as it depends on a yet uncomputed
        // variable.
        continue;

      // Add constant term to AffineExpr.
      expr = expr + atEq(idx, getNumVars());
      // Solve vPos * var + expr = 0 for var.
      int64_t vPos = atEq(idx, pos);
      assert(vPos != 0 && "expected non-zero here");
      if (vPos > 0)
        expr = (-expr).floorDiv(vPos);
      else
        // vPos < 0.
        expr = expr.floorDiv(-vPos);
      // Successfully constructed expression.
      memo[pos] = expr;
      changed = true;
    }
    // This loop is guaranteed to reach a fixed point - since once a
    // variable's explicit form is computed (in memo[pos]), it's not updated
    // again.
  } while (changed);

  int64_t ubAdjustment = getClosedUB ? 0 : 1;

  // Set the lower and upper bound maps for all the variables that were
  // computed as affine expressions of the rest as the "detected expr" and
  // "detected expr + 1" respectively (just "detected expr" for the upper bound
  // when `getClosedUB` is set). Undetected ones fall back to bounds derived
  // from the constraints, or to constant bounds, when available.
  Optional<FlatAffineValueConstraints> tmpClone;
  for (unsigned pos = 0; pos < num; pos++) {
    unsigned numMapDims = getNumDimVars() - num;
    unsigned numMapSymbols = getNumSymbolVars();
    AffineExpr expr = memo[pos + offset];
    if (expr)
      expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);

    AffineMap &lbMap = (*lbMaps)[pos];
    AffineMap &ubMap = (*ubMaps)[pos];

    if (expr) {
      lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
      ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + ubAdjustment);
    } else {
      // TODO: Whenever there are local variables in the dependence
      // constraints, we'll conservatively over-approximate, since we don't
      // always explicitly compute them above (in the while loop).
      if (getNumLocalVars() == 0) {
        // Work on a copy so that we don't update this constraint system.
        if (!tmpClone) {
          tmpClone.emplace(FlatAffineValueConstraints(*this));
          // Removing redundant inequalities is necessary so that we don't get
          // redundant loop bounds.
          tmpClone->removeRedundantInequalities();
        }
        std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
            pos, offset, num, getNumDimVars(), /*localExprs=*/{}, context);
      }

      // If the above fails, we'll just use the constant lower bound and the
      // constant upper bound (if they exist) as the slice bounds.
      // TODO: being conservative for the moment in cases that
      // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
      // fixed (b/126426796).
      if (!lbMap || lbMap.getNumResults() > 1) {
        LLVM_DEBUG(llvm::dbgs()
                   << "WARNING: Potentially over-approximating slice lb\n");
        auto lbConst = getConstantBound(BoundType::LB, pos + offset);
        if (lbConst.hasValue()) {
          lbMap = AffineMap::get(
              numMapDims, numMapSymbols,
              getAffineConstantExpr(lbConst.getValue(), context));
        }
      }
      if (!ubMap || ubMap.getNumResults() > 1) {
        LLVM_DEBUG(llvm::dbgs()
                   << "WARNING: Potentially over-approximating slice ub\n");
        auto ubConst = getConstantBound(BoundType::UB, pos + offset);
        if (ubConst.hasValue()) {
          ubMap =
              AffineMap::get(numMapDims, numMapSymbols,
                             getAffineConstantExpr(
                                 ubConst.getValue() + ubAdjustment, context));
        }
      }
    }
    LLVM_DEBUG(llvm::dbgs()
               << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
    LLVM_DEBUG(lbMap.dump(););
    LLVM_DEBUG(llvm::dbgs()
               << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
    LLVM_DEBUG(ubMap.dump(););
  }
}
1149 
1150 LogicalResult FlatAffineValueConstraints::flattenAlignedMapAndMergeLocals(
1151     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) {
1152   FlatAffineValueConstraints localCst;
1153   if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) {
1154     LLVM_DEBUG(llvm::dbgs()
1155                << "composition unimplemented for semi-affine maps\n");
1156     return failure();
1157   }
1158 
1159   // Add localCst information.
1160   if (localCst.getNumLocalVars() > 0) {
1161     unsigned numLocalVars = getNumLocalVars();
1162     // Insert local dims of localCst at the beginning.
1163     insertLocalVar(/*pos=*/0, /*num=*/localCst.getNumLocalVars());
1164     // Insert local dims of `this` at the end of localCst.
1165     localCst.appendLocalVar(/*num=*/numLocalVars);
1166     // Dimensions of localCst and this constraint set match. Append localCst to
1167     // this constraint set.
1168     append(localCst);
1169   }
1170 
1171   return success();
1172 }
1173 
1174 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
1175                                                    AffineMap boundMap,
1176                                                    bool isClosedBound) {
1177   assert(boundMap.getNumDims() == getNumDimVars() && "dim mismatch");
1178   assert(boundMap.getNumSymbols() == getNumSymbolVars() && "symbol mismatch");
1179   assert(pos < getNumDimAndSymbolVars() && "invalid position");
1180   assert((type != BoundType::EQ || isClosedBound) &&
1181          "EQ bound must be closed.");
1182 
1183   // Equality follows the logic of lower bound except that we add an equality
1184   // instead of an inequality.
1185   assert((type != BoundType::EQ || boundMap.getNumResults() == 1) &&
1186          "single result expected");
1187   bool lower = type == BoundType::LB || type == BoundType::EQ;
1188 
1189   std::vector<SmallVector<int64_t, 8>> flatExprs;
1190   if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs)))
1191     return failure();
1192   assert(flatExprs.size() == boundMap.getNumResults());
1193 
1194   // Add one (in)equality for each result.
1195   for (const auto &flatExpr : flatExprs) {
1196     SmallVector<int64_t> ineq(getNumCols(), 0);
1197     // Dims and symbols.
1198     for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
1199       ineq[j] = lower ? -flatExpr[j] : flatExpr[j];
1200     }
1201     // Invalid bound: pos appears in `boundMap`.
1202     // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or
1203     // its callers to prevent invalid bounds from being added.
1204     if (ineq[pos] != 0)
1205       continue;
1206     ineq[pos] = lower ? 1 : -1;
1207     // Local columns of `ineq` are at the beginning.
1208     unsigned j = getNumDimVars() + getNumSymbolVars();
1209     unsigned end = flatExpr.size() - 1;
1210     for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
1211       ineq[j] = lower ? -flatExpr[i] : flatExpr[i];
1212     }
1213     // Make the bound closed in if flatExpr is open. The inequality is always
1214     // created in the upper bound form, so the adjustment is -1.
1215     int64_t boundAdjustment = (isClosedBound || type == BoundType::EQ) ? 0 : -1;
1216     // Constant term.
1217     ineq[getNumCols() - 1] = (lower ? -flatExpr[flatExpr.size() - 1]
1218                                     : flatExpr[flatExpr.size() - 1]) +
1219                              boundAdjustment;
1220     type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq);
1221   }
1222 
1223   return success();
1224 }
1225 
1226 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
1227                                                    AffineMap boundMap) {
1228   return addBound(type, pos, boundMap, /*isClosedBound=*/type != BoundType::UB);
1229 }
1230 
1231 AffineMap
1232 FlatAffineValueConstraints::computeAlignedMap(AffineMap map,
1233                                               ValueRange operands) const {
1234   assert(map.getNumInputs() == operands.size() && "number of inputs mismatch");
1235 
1236   SmallVector<Value> dims, syms;
1237 #ifndef NDEBUG
1238   SmallVector<Value> newSyms;
1239   SmallVector<Value> *newSymsPtr = &newSyms;
1240 #else
1241   SmallVector<Value> *newSymsPtr = nullptr;
1242 #endif // NDEBUG
1243 
1244   dims.reserve(getNumDimVars());
1245   syms.reserve(getNumSymbolVars());
1246   for (unsigned i = getVarKindOffset(VarKind::SetDim),
1247                 e = getVarKindEnd(VarKind::SetDim);
1248        i < e; ++i)
1249     dims.push_back(values[i] ? *values[i] : Value());
1250   for (unsigned i = getVarKindOffset(VarKind::Symbol),
1251                 e = getVarKindEnd(VarKind::Symbol);
1252        i < e; ++i)
1253     syms.push_back(values[i] ? *values[i] : Value());
1254 
1255   AffineMap alignedMap =
1256       alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
1257   // All symbols are already part of this FlatAffineConstraints.
1258   assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
1259   assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
1260          "unexpected new/missing symbols");
1261   return alignedMap;
1262 }
1263 
1264 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
1265                                                    AffineMap boundMap,
1266                                                    ValueRange boundOperands) {
1267   // Fully compose map and operands; canonicalize and simplify so that we
1268   // transitively get to terminal symbols or loop IVs.
1269   auto map = boundMap;
1270   SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
1271   fullyComposeAffineMapAndOperands(&map, &operands);
1272   map = simplifyAffineMap(map);
1273   canonicalizeMapAndOperands(&map, &operands);
1274   for (auto operand : operands)
1275     addInductionVarOrTerminalSymbol(operand);
1276   return addBound(type, pos, computeAlignedMap(map, operands));
1277 }
1278 
1279 // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
1280 // bounds in 'ubMaps' to each value in `values' that appears in the constraint
1281 // system. Note that both lower/upper bounds share the same operand list
1282 // 'operands'.
1283 // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and
1284 // skips any null AffineMaps in 'lbMaps' or 'ubMaps'.
1285 // Note that both lower/upper bounds use operands from 'operands'.
1286 // Returns failure for unimplemented cases such as semi-affine expressions or
1287 // expressions with mod/floordiv.
1288 LogicalResult FlatAffineValueConstraints::addSliceBounds(
1289     ArrayRef<Value> values, ArrayRef<AffineMap> lbMaps,
1290     ArrayRef<AffineMap> ubMaps, ArrayRef<Value> operands) {
1291   assert(values.size() == lbMaps.size());
1292   assert(lbMaps.size() == ubMaps.size());
1293 
1294   for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
1295     unsigned pos;
1296     if (!findVar(values[i], &pos))
1297       continue;
1298 
1299     AffineMap lbMap = lbMaps[i];
1300     AffineMap ubMap = ubMaps[i];
1301     assert(!lbMap || lbMap.getNumInputs() == operands.size());
1302     assert(!ubMap || ubMap.getNumInputs() == operands.size());
1303 
1304     // Check if this slice is just an equality along this dimension.
1305     if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
1306         ubMap.getNumResults() == 1 &&
1307         lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
1308       if (failed(addBound(BoundType::EQ, pos, lbMap, operands)))
1309         return failure();
1310       continue;
1311     }
1312 
1313     // If lower or upper bound maps are null or provide no results, it implies
1314     // that the source loop was not at all sliced, and the entire loop will be a
1315     // part of the slice.
1316     if (lbMap && lbMap.getNumResults() != 0 && ubMap &&
1317         ubMap.getNumResults() != 0) {
1318       if (failed(addBound(BoundType::LB, pos, lbMap, operands)))
1319         return failure();
1320       if (failed(addBound(BoundType::UB, pos, ubMap, operands)))
1321         return failure();
1322     } else {
1323       auto loop = getForInductionVarOwner(values[i]);
1324       if (failed(this->addAffineForOpDomain(loop)))
1325         return failure();
1326     }
1327   }
1328   return success();
1329 }
1330 
1331 bool FlatAffineValueConstraints::findVar(Value val, unsigned *pos) const {
1332   unsigned i = 0;
1333   for (const auto &mayBeVar : values) {
1334     if (mayBeVar && *mayBeVar == val) {
1335       *pos = i;
1336       return true;
1337     }
1338     i++;
1339   }
1340   return false;
1341 }
1342 
1343 bool FlatAffineValueConstraints::containsVar(Value val) const {
1344   return llvm::any_of(values, [&](const Optional<Value> &mayBeVar) {
1345     return mayBeVar && *mayBeVar == val;
1346   });
1347 }
1348 
1349 void FlatAffineValueConstraints::swapVar(unsigned posA, unsigned posB) {
1350   IntegerPolyhedron::swapVar(posA, posB);
1351 
1352   if (getVarKindAt(posA) == VarKind::Local &&
1353       getVarKindAt(posB) == VarKind::Local)
1354     return;
1355 
1356   // Treat value of a local variable as None.
1357   if (getVarKindAt(posA) == VarKind::Local)
1358     values[posB] = None;
1359   else if (getVarKindAt(posB) == VarKind::Local)
1360     values[posA] = None;
1361   else
1362     std::swap(values[posA], values[posB]);
1363 }
1364 
1365 void FlatAffineValueConstraints::addBound(BoundType type, Value val,
1366                                           int64_t value) {
1367   unsigned pos;
1368   if (!findVar(val, &pos))
1369     // This is a pre-condition for this method.
1370     assert(0 && "var not found");
1371   addBound(type, pos, value);
1372 }
1373 
1374 void FlatAffineValueConstraints::printSpace(raw_ostream &os) const {
1375   IntegerPolyhedron::printSpace(os);
1376   os << "(";
1377   for (unsigned i = 0, e = getNumDimAndSymbolVars(); i < e; i++) {
1378     if (hasValue(i))
1379       os << "Value ";
1380     else
1381       os << "None ";
1382   }
1383   for (unsigned i = getVarKindOffset(VarKind::Local),
1384                 e = getVarKindEnd(VarKind::Local);
1385        i < e; ++i)
1386     os << "Local ";
1387   os << " const)\n";
1388 }
1389 
1390 void FlatAffineValueConstraints::clearAndCopyFrom(
1391     const IntegerRelation &other) {
1392 
1393   if (auto *otherValueSet =
1394           dyn_cast<const FlatAffineValueConstraints>(&other)) {
1395     *this = *otherValueSet;
1396   } else {
1397     *static_cast<IntegerRelation *>(this) = other;
1398     values.clear();
1399     values.resize(getNumDimAndSymbolVars(), None);
1400   }
1401 }
1402 
1403 void FlatAffineValueConstraints::fourierMotzkinEliminate(
1404     unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
1405   SmallVector<Optional<Value>, 8> newVals = values;
1406   if (getVarKindAt(pos) != VarKind::Local)
1407     newVals.erase(newVals.begin() + pos);
1408   // Note: Base implementation discards all associated Values.
1409   IntegerPolyhedron::fourierMotzkinEliminate(pos, darkShadow,
1410                                              isResultIntegerExact);
1411   values = newVals;
1412   assert(values.size() == getNumDimAndSymbolVars());
1413 }
1414 
/// Eliminates the variable bound to `val` via Fourier-Motzkin elimination.
/// `val` must be bound to some variable of this system (asserted in debug
/// builds only).
void FlatAffineValueConstraints::projectOut(Value val) {
  unsigned pos;
  bool ret = findVar(val, &pos);
  assert(ret);
  // Silence unused-variable warnings when asserts are compiled out.
  (void)ret;
  fourierMotzkinEliminate(pos);
}
1422 
1423 LogicalResult FlatAffineValueConstraints::unionBoundingBox(
1424     const FlatAffineValueConstraints &otherCst) {
1425   assert(otherCst.getNumDimVars() == getNumDimVars() && "dims mismatch");
1426   assert(otherCst.getMaybeValues()
1427              .slice(0, getNumDimVars())
1428              .equals(getMaybeValues().slice(0, getNumDimVars())) &&
1429          "dim values mismatch");
1430   assert(otherCst.getNumLocalVars() == 0 && "local vars not supported here");
1431   assert(getNumLocalVars() == 0 && "local vars not supported yet here");
1432 
1433   // Align `other` to this.
1434   if (!areVarsAligned(*this, otherCst)) {
1435     FlatAffineValueConstraints otherCopy(otherCst);
1436     mergeAndAlignVars(/*offset=*/getNumDimVars(), this, &otherCopy);
1437     return IntegerPolyhedron::unionBoundingBox(otherCopy);
1438   }
1439 
1440   return IntegerPolyhedron::unionBoundingBox(otherCst);
1441 }
1442 
1443 /// Compute an explicit representation for local vars. For all systems coming
1444 /// from MLIR integer sets, maps, or expressions where local vars were
1445 /// introduced to model floordivs and mods, this always succeeds.
1446 static LogicalResult computeLocalVars(const FlatAffineValueConstraints &cst,
1447                                       SmallVectorImpl<AffineExpr> &memo,
1448                                       MLIRContext *context) {
1449   unsigned numDims = cst.getNumDimVars();
1450   unsigned numSyms = cst.getNumSymbolVars();
1451 
1452   // Initialize dimensional and symbolic variables.
1453   for (unsigned i = 0; i < numDims; i++)
1454     memo[i] = getAffineDimExpr(i, context);
1455   for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
1456     memo[i] = getAffineSymbolExpr(i - numDims, context);
1457 
1458   bool changed;
1459   do {
1460     // Each time `changed` is true at the end of this iteration, one or more
1461     // local vars would have been detected as floordivs and set in memo; so the
1462     // number of null entries in memo[...] strictly reduces; so this converges.
1463     changed = false;
1464     for (unsigned i = 0, e = cst.getNumLocalVars(); i < e; ++i)
1465       if (!memo[numDims + numSyms + i] &&
1466           detectAsFloorDiv(cst, /*pos=*/numDims + numSyms + i, context, memo))
1467         changed = true;
1468   } while (changed);
1469 
1470   ArrayRef<AffineExpr> localExprs =
1471       ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalVars());
1472   return success(
1473       llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
1474 }
1475 
1476 void FlatAffineValueConstraints::getIneqAsAffineValueMap(
1477     unsigned pos, unsigned ineqPos, AffineValueMap &vmap,
1478     MLIRContext *context) const {
1479   unsigned numDims = getNumDimVars();
1480   unsigned numSyms = getNumSymbolVars();
1481 
1482   assert(pos < numDims && "invalid position");
1483   assert(ineqPos < getNumInequalities() && "invalid inequality position");
1484 
1485   // Get expressions for local vars.
1486   SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr());
1487   if (failed(computeLocalVars(*this, memo, context)))
1488     assert(false &&
1489            "one or more local exprs do not have an explicit representation");
1490   auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars());
1491 
1492   // Compute the AffineExpr lower/upper bound for this inequality.
1493   ArrayRef<int64_t> inequality = getInequality(ineqPos);
1494   SmallVector<int64_t, 8> bound;
1495   bound.reserve(getNumCols() - 1);
1496   // Everything other than the coefficient at `pos`.
1497   bound.append(inequality.begin(), inequality.begin() + pos);
1498   bound.append(inequality.begin() + pos + 1, inequality.end());
1499 
1500   if (inequality[pos] > 0)
1501     // Lower bound.
1502     std::transform(bound.begin(), bound.end(), bound.begin(),
1503                    std::negate<int64_t>());
1504   else
1505     // Upper bound (which is exclusive).
1506     bound.back() += 1;
1507 
1508   // Convert to AffineExpr (tree) form.
1509   auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms,
1510                                              localExprs, context);
1511 
1512   // Get the values to bind to this affine expr (all dims and symbols).
1513   SmallVector<Value, 4> operands;
1514   getValues(0, pos, &operands);
1515   SmallVector<Value, 4> trailingOperands;
1516   getValues(pos + 1, getNumDimAndSymbolVars(), &trailingOperands);
1517   operands.append(trailingOperands.begin(), trailingOperands.end());
1518   vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands);
1519 }
1520 
1521 IntegerSet
1522 FlatAffineValueConstraints::getAsIntegerSet(MLIRContext *context) const {
1523   if (getNumConstraints() == 0)
1524     // Return universal set (always true): 0 == 0.
1525     return IntegerSet::get(getNumDimVars(), getNumSymbolVars(),
1526                            getAffineConstantExpr(/*constant=*/0, context),
1527                            /*eqFlags=*/true);
1528 
1529   // Construct local references.
1530   SmallVector<AffineExpr, 8> memo(getNumVars(), AffineExpr());
1531 
1532   if (failed(computeLocalVars(*this, memo, context))) {
1533     // Check if the local variables without an explicit representation have
1534     // zero coefficients everywhere.
1535     SmallVector<unsigned> noLocalRepVars;
1536     unsigned numDimsSymbols = getNumDimAndSymbolVars();
1537     for (unsigned i = numDimsSymbols, e = getNumVars(); i < e; ++i) {
1538       if (!memo[i] && !isColZero(/*pos=*/i))
1539         noLocalRepVars.push_back(i - numDimsSymbols);
1540     }
1541     if (!noLocalRepVars.empty()) {
1542       LLVM_DEBUG({
1543         llvm::dbgs() << "local variables at position(s) ";
1544         llvm::interleaveComma(noLocalRepVars, llvm::dbgs());
1545         llvm::dbgs() << " do not have an explicit representation in:\n";
1546         this->dump();
1547       });
1548       return IntegerSet();
1549     }
1550   }
1551 
1552   ArrayRef<AffineExpr> localExprs =
1553       ArrayRef<AffineExpr>(memo).take_back(getNumLocalVars());
1554 
1555   // Construct the IntegerSet from the equalities/inequalities.
1556   unsigned numDims = getNumDimVars();
1557   unsigned numSyms = getNumSymbolVars();
1558 
1559   SmallVector<bool, 16> eqFlags(getNumConstraints());
1560   std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
1561   std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false);
1562 
1563   SmallVector<AffineExpr, 8> exprs;
1564   exprs.reserve(getNumConstraints());
1565 
1566   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
1567     exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms,
1568                                               localExprs, context));
1569   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
1570     exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims,
1571                                               numSyms, localExprs, context));
1572   return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
1573 }
1574 
1575 AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
1576                                          ValueRange dims, ValueRange syms,
1577                                          SmallVector<Value> *newSyms) {
1578   assert(operands.size() == map.getNumInputs() &&
1579          "expected same number of operands and map inputs");
1580   MLIRContext *ctx = map.getContext();
1581   Builder builder(ctx);
1582   SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {});
1583   unsigned numSymbols = syms.size();
1584   SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {});
1585   if (newSyms) {
1586     newSyms->clear();
1587     newSyms->append(syms.begin(), syms.end());
1588   }
1589 
1590   for (const auto &operand : llvm::enumerate(operands)) {
1591     // Compute replacement dim/sym of operand.
1592     AffineExpr replacement;
1593     auto dimIt = std::find(dims.begin(), dims.end(), operand.value());
1594     auto symIt = std::find(syms.begin(), syms.end(), operand.value());
1595     if (dimIt != dims.end()) {
1596       replacement =
1597           builder.getAffineDimExpr(std::distance(dims.begin(), dimIt));
1598     } else if (symIt != syms.end()) {
1599       replacement =
1600           builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt));
1601     } else {
1602       // This operand is neither a dimension nor a symbol. Add it as a new
1603       // symbol.
1604       replacement = builder.getAffineSymbolExpr(numSymbols++);
1605       if (newSyms)
1606         newSyms->push_back(operand.value());
1607     }
1608     // Add to corresponding replacements vector.
1609     if (operand.index() < map.getNumDims()) {
1610       dimReplacements[operand.index()] = replacement;
1611     } else {
1612       symReplacements[operand.index() - map.getNumDims()] = replacement;
1613     }
1614   }
1615 
1616   return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
1617                                    dims.size(), numSymbols);
1618 }
1619 
1620 FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const {
1621   FlatAffineValueConstraints domain = *this;
1622   // Convert all range variables to local variables.
1623   domain.convertToLocal(VarKind::SetDim, getNumDomainDims(),
1624                         getNumDomainDims() + getNumRangeDims());
1625   return domain;
1626 }
1627 
1628 FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const {
1629   FlatAffineValueConstraints range = *this;
1630   // Convert all domain variables to local variables.
1631   range.convertToLocal(VarKind::SetDim, 0, getNumDomainDims());
1632   return range;
1633 }
1634 
/// Composes `other` into this relation, storing the result in `this`.
///
/// Preconditions (asserted): the domain of `this` has as many dims as the
/// range of `other`, and their attached values agree. Given
///   this  : [thisDomain]  -> [thisRange]
///   other : [otherDomain] -> [otherRange]   (otherRange matches thisDomain)
/// the result is [otherDomain] -> [thisRange]. The shared intermediate dims are
/// turned into local vars in both relations before they are intersected.
void FlatAffineRelation::compose(const FlatAffineRelation &other) {
  assert(getNumDomainDims() == other.getNumRangeDims() &&
         "Domain of this and range of other do not match");
  assert(std::equal(values.begin(), values.begin() + getNumDomainDims(),
                    other.values.begin() + other.getNumDomainDims()) &&
         "Domain of this and range of other do not match");

  // Work on a copy of `other`; the result is built in `rel` and assigned back
  // to `this` at the end.
  FlatAffineRelation rel = other;

  // Convert `rel` from
  //    [otherDomain] -> [otherRange]
  // to
  //    [otherDomain] -> [otherRange thisRange]
  // and `this` from
  //    [thisDomain] -> [thisRange]
  // to
  //    [otherDomain thisDomain] -> [thisRange].
  // `removeDims` is the number of intermediate dims (|otherRange| ==
  // |thisDomain|). It is captured before the inserts below change the counts.
  unsigned removeDims = rel.getNumRangeDims();
  insertDomainVar(0, rel.getNumDomainDims());
  rel.appendRangeVar(getNumRangeDims());

  // Merge symbol and local variables.
  // NOTE(review): presumably this makes the symbol and local var lists of
  // `this` and `rel` identical so the later append() lines up; confirm.
  mergeSymbolVars(rel);
  mergeLocalVars(rel);

  // Convert `rel` from [otherDomain] -> [otherRange thisRange] to
  // [otherDomain] -> [thisRange] by converting first otherRange range vars
  // to local vars.
  rel.convertToLocal(VarKind::SetDim, rel.getNumDomainDims(),
                     rel.getNumDomainDims() + removeDims);
  // Convert `this` from [otherDomain thisDomain] -> [thisRange] to
  // [otherDomain] -> [thisRange] by converting last thisDomain domain vars
  // to local vars.
  convertToLocal(VarKind::SetDim, getNumDomainDims() - removeDims,
                 getNumDomainDims());

  // Both relations now have the layout [otherDomain] -> [thisRange]. Copy
  // attached values across so that whichever side knows a dim's value
  // provides it to the other.
  auto thisMaybeValues = getMaybeValues(VarKind::SetDim);
  auto relMaybeValues = rel.getMaybeValues(VarKind::SetDim);

  // Copy the known domain values of `rel` into the domain of `this`.
  for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
    if (relMaybeValues[i].hasValue())
      setValue(i, relMaybeValues[i].getValue());
  // Copy the known range values of `this` into the range of `rel`.
  for (unsigned i = 0, e = getNumRangeDims(); i < e; ++i) {
    unsigned rangeIdx = rel.getNumDomainDims() + i;
    if (thisMaybeValues[rangeIdx].hasValue())
      rel.setValue(rangeIdx, thisMaybeValues[rangeIdx].getValue());
  }

  // Append `this` to `rel` and simplify constraints.
  rel.append(*this);
  rel.removeRedundantLocalVars();

  *this = rel;
}
1691 
1692 void FlatAffineRelation::inverse() {
1693   unsigned oldDomain = getNumDomainDims();
1694   unsigned oldRange = getNumRangeDims();
1695   // Add new range vars.
1696   appendRangeVar(oldDomain);
1697   // Swap new vars with domain.
1698   for (unsigned i = 0; i < oldDomain; ++i)
1699     swapVar(i, oldDomain + oldRange + i);
1700   // Remove the swapped domain.
1701   removeVarRange(0, oldDomain);
1702   // Set domain and range as inverse.
1703   numDomainDims = oldRange;
1704   numRangeDims = oldDomain;
1705 }
1706 
1707 void FlatAffineRelation::insertDomainVar(unsigned pos, unsigned num) {
1708   assert(pos <= getNumDomainDims() &&
1709          "Var cannot be inserted at invalid position");
1710   insertDimVar(pos, num);
1711   numDomainDims += num;
1712 }
1713 
1714 void FlatAffineRelation::insertRangeVar(unsigned pos, unsigned num) {
1715   assert(pos <= getNumRangeDims() &&
1716          "Var cannot be inserted at invalid position");
1717   insertDimVar(getNumDomainDims() + pos, num);
1718   numRangeDims += num;
1719 }
1720 
1721 void FlatAffineRelation::appendDomainVar(unsigned num) {
1722   insertDimVar(getNumDomainDims(), num);
1723   numDomainDims += num;
1724 }
1725 
1726 void FlatAffineRelation::appendRangeVar(unsigned num) {
1727   insertDimVar(getNumDimVars(), num);
1728   numRangeDims += num;
1729 }
1730 
1731 void FlatAffineRelation::removeVarRange(VarKind kind, unsigned varStart,
1732                                         unsigned varLimit) {
1733   assert(varLimit <= getNumVarKind(kind));
1734   if (varStart >= varLimit)
1735     return;
1736 
1737   FlatAffineValueConstraints::removeVarRange(kind, varStart, varLimit);
1738 
1739   // If kind is not SetDim, domain and range don't need to be updated.
1740   if (kind != VarKind::SetDim)
1741     return;
1742 
1743   // Compute number of domain and range variables to remove. This is done by
1744   // intersecting the range of domain/range vars with range of vars to remove.
1745   unsigned intersectDomainLHS = std::min(varLimit, getNumDomainDims());
1746   unsigned intersectDomainRHS = varStart;
1747   unsigned intersectRangeLHS = std::min(varLimit, getNumDimVars());
1748   unsigned intersectRangeRHS = std::max(varStart, getNumDomainDims());
1749 
1750   if (intersectDomainLHS > intersectDomainRHS)
1751     numDomainDims -= intersectDomainLHS - intersectDomainRHS;
1752   if (intersectRangeLHS > intersectRangeRHS)
1753     numRangeDims -= intersectRangeLHS - intersectRangeRHS;
1754 }
1755 
1756 LogicalResult mlir::getRelationFromMap(AffineMap &map,
1757                                        FlatAffineRelation &rel) {
1758   // Get flattened affine expressions.
1759   std::vector<SmallVector<int64_t, 8>> flatExprs;
1760   FlatAffineValueConstraints localVarCst;
1761   if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst)))
1762     return failure();
1763 
1764   unsigned oldDimNum = localVarCst.getNumDimVars();
1765   unsigned oldCols = localVarCst.getNumCols();
1766   unsigned numRangeVars = map.getNumResults();
1767   unsigned numDomainVars = map.getNumDims();
1768 
1769   // Add range as the new expressions.
1770   localVarCst.appendDimVar(numRangeVars);
1771 
1772   // Add equalities between source and range.
1773   SmallVector<int64_t, 8> eq(localVarCst.getNumCols());
1774   for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1775     // Zero fill.
1776     std::fill(eq.begin(), eq.end(), 0);
1777     // Fill equality.
1778     for (unsigned j = 0, f = oldDimNum; j < f; ++j)
1779       eq[j] = flatExprs[i][j];
1780     for (unsigned j = oldDimNum, f = oldCols; j < f; ++j)
1781       eq[j + numRangeVars] = flatExprs[i][j];
1782     // Set this dimension to -1 to equate lhs and rhs and add equality.
1783     eq[numDomainVars + i] = -1;
1784     localVarCst.addEquality(eq);
1785   }
1786 
1787   // Create relation and return success.
1788   rel = FlatAffineRelation(numDomainVars, numRangeVars, localVarCst);
1789   return success();
1790 }
1791 
1792 LogicalResult mlir::getRelationFromMap(const AffineValueMap &map,
1793                                        FlatAffineRelation &rel) {
1794 
1795   AffineMap affineMap = map.getAffineMap();
1796   if (failed(getRelationFromMap(affineMap, rel)))
1797     return failure();
1798 
1799   // Set symbol values for domain dimensions and symbols.
1800   for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
1801     rel.setValue(i, map.getOperand(i));
1802   for (unsigned i = rel.getNumDimVars(), e = rel.getNumDimAndSymbolVars();
1803        i < e; ++i)
1804     rel.setValue(i, map.getOperand(i - rel.getNumRangeDims()));
1805 
1806   return success();
1807 }
1808