1 //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Structures for affine/polyhedral analysis of affine dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
14 #include "mlir/Analysis/Presburger/LinearTransform.h"
15 #include "mlir/Analysis/Presburger/Simplex.h"
16 #include "mlir/Analysis/Presburger/Utils.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
20 #include "mlir/IR/AffineExprVisitor.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/Support/LLVM.h"
23 #include "mlir/Support/MathExtras.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 #define DEBUG_TYPE "affine-structures"
31 
32 using namespace mlir;
33 using namespace presburger;
34 
35 namespace {
36 
37 // See comments for SimpleAffineExprFlattener.
38 // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
39 // constraint information associated with mod's, floordiv's, and ceildiv's
40 // in FlatAffineConstraints 'localVarCst'.
41 struct AffineExprFlattener : public SimpleAffineExprFlattener {
42 public:
43   // Constraints connecting newly introduced local variables (for mod's and
44   // div's) to existing (dimensional and symbolic) ones. These are always
45   // inequalities.
46   FlatAffineConstraints localVarCst;
47 
48   AffineExprFlattener(unsigned nDims, unsigned nSymbols)
49       : SimpleAffineExprFlattener(nDims, nSymbols) {
50     localVarCst.reset(nDims, nSymbols, /*numLocals=*/0);
51   }
52 
53 private:
54   // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
55   // The local identifier added is always a floordiv of a pure add/mul affine
56   // function of other identifiers, coefficients of which are specified in
57   // `dividend' and with respect to the positive constant `divisor'. localExpr
58   // is the simplified tree expression (AffineExpr) corresponding to the
59   // quantifier.
60   void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
61                           AffineExpr localExpr) override {
62     SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
63     // Update localVarCst.
64     localVarCst.addLocalFloorDiv(dividend, divisor);
65   }
66 };
67 
68 } // namespace
69 
70 // Flattens the expressions in map. Returns failure if 'expr' was unable to be
71 // flattened (i.e., semi-affine expressions not handled yet).
72 static LogicalResult
73 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
74                         unsigned numSymbols,
75                         std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
76                         FlatAffineConstraints *localVarCst) {
77   if (exprs.empty()) {
78     localVarCst->reset(numDims, numSymbols);
79     return success();
80   }
81 
82   AffineExprFlattener flattener(numDims, numSymbols);
83   // Use the same flattener to simplify each expression successively. This way
84   // local identifiers / expressions are shared.
85   for (auto expr : exprs) {
86     if (!expr.isPureAffine())
87       return failure();
88 
89     flattener.walkPostOrder(expr);
90   }
91 
92   assert(flattener.operandExprStack.size() == exprs.size());
93   flattenedExprs->clear();
94   flattenedExprs->assign(flattener.operandExprStack.begin(),
95                          flattener.operandExprStack.end());
96 
97   if (localVarCst)
98     localVarCst->clearAndCopyFrom(flattener.localVarCst);
99 
100   return success();
101 }
102 
103 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
104 // be flattened (semi-affine expressions not handled yet).
105 LogicalResult
106 mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
107                              unsigned numSymbols,
108                              SmallVectorImpl<int64_t> *flattenedExpr,
109                              FlatAffineConstraints *localVarCst) {
110   std::vector<SmallVector<int64_t, 8>> flattenedExprs;
111   LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
112                                                 &flattenedExprs, localVarCst);
113   *flattenedExpr = flattenedExprs[0];
114   return ret;
115 }
116 
117 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be
118 /// flattened (i.e., semi-affine expressions not handled yet).
119 LogicalResult mlir::getFlattenedAffineExprs(
120     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
121     FlatAffineConstraints *localVarCst) {
122   if (map.getNumResults() == 0) {
123     localVarCst->reset(map.getNumDims(), map.getNumSymbols());
124     return success();
125   }
126   return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
127                                    map.getNumSymbols(), flattenedExprs,
128                                    localVarCst);
129 }
130 
131 LogicalResult mlir::getFlattenedAffineExprs(
132     IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
133     FlatAffineConstraints *localVarCst) {
134   if (set.getNumConstraints() == 0) {
135     localVarCst->reset(set.getNumDims(), set.getNumSymbols());
136     return success();
137   }
138   return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
139                                    set.getNumSymbols(), flattenedExprs,
140                                    localVarCst);
141 }
142 
143 //===----------------------------------------------------------------------===//
144 // FlatAffineConstraints / FlatAffineValueConstraints.
145 //===----------------------------------------------------------------------===//
146 
147 // Clones this object.
148 std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
149   return std::make_unique<FlatAffineConstraints>(*this);
150 }
151 
152 std::unique_ptr<FlatAffineValueConstraints>
153 FlatAffineValueConstraints::clone() const {
154   return std::make_unique<FlatAffineValueConstraints>(*this);
155 }
156 
157 // Construct from an IntegerSet.
158 FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
159     : IntegerPolyhedron(set.getNumInequalities(), set.getNumEqualities(),
160                         set.getNumDims() + set.getNumSymbols() + 1,
161                         set.getNumDims(), set.getNumSymbols(),
162                         /*numLocals=*/0) {
163 
164   // Flatten expressions and add them to the constraint system.
165   std::vector<SmallVector<int64_t, 8>> flatExprs;
166   FlatAffineConstraints localVarCst;
167   if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
168     assert(false && "flattening unimplemented for semi-affine integer sets");
169     return;
170   }
171   assert(flatExprs.size() == set.getNumConstraints());
172   appendLocalId(/*num=*/localVarCst.getNumLocalIds());
173 
174   for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
175     const auto &flatExpr = flatExprs[i];
176     assert(flatExpr.size() == getNumCols());
177     if (set.getEqFlags()[i]) {
178       addEquality(flatExpr);
179     } else {
180       addInequality(flatExpr);
181     }
182   }
183   // Add the other constraints involving local id's from flattening.
184   append(localVarCst);
185 }
186 
187 // Construct from an IntegerSet.
188 FlatAffineValueConstraints::FlatAffineValueConstraints(IntegerSet set)
189     : FlatAffineConstraints(set) {
190   values.resize(getNumIds(), None);
191 }
192 
193 // Construct a hyperrectangular constraint set from ValueRanges that represent
194 // induction variables, lower and upper bounds. `ivs`, `lbs` and `ubs` are
195 // expected to match one to one. The order of variables and constraints is:
196 //
197 // ivs | lbs | ubs | eq/ineq
198 // ----+-----+-----+---------
199 //   1   -1     0      >= 0
200 // ----+-----+-----+---------
201 //  -1    0     1      >= 0
202 //
203 // All dimensions as set as DimId.
204 FlatAffineValueConstraints
205 FlatAffineValueConstraints::getHyperrectangular(ValueRange ivs, ValueRange lbs,
206                                                 ValueRange ubs) {
207   FlatAffineValueConstraints res;
208   unsigned nIvs = ivs.size();
209   assert(nIvs == lbs.size() && "expected as many lower bounds as ivs");
210   assert(nIvs == ubs.size() && "expected as many upper bounds as ivs");
211 
212   if (nIvs == 0)
213     return res;
214 
215   res.appendDimId(ivs);
216   unsigned lbsStart = res.appendDimId(lbs);
217   unsigned ubsStart = res.appendDimId(ubs);
218 
219   MLIRContext *ctx = ivs.front().getContext();
220   for (int ivIdx = 0, e = nIvs; ivIdx < e; ++ivIdx) {
221     // iv - lb >= 0
222     AffineMap lb = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
223                                   getAffineDimExpr(lbsStart + ivIdx, ctx));
224     if (failed(res.addBound(BoundType::LB, ivIdx, lb)))
225       llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
226     // -iv + ub >= 0
227     AffineMap ub = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
228                                   getAffineDimExpr(ubsStart + ivIdx, ctx));
229     if (failed(res.addBound(BoundType::UB, ivIdx, ub)))
230       llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
231   }
232   return res;
233 }
234 
235 void FlatAffineConstraints::reset(unsigned numReservedInequalities,
236                                   unsigned numReservedEqualities,
237                                   unsigned newNumReservedCols,
238                                   unsigned newNumDims, unsigned newNumSymbols,
239                                   unsigned newNumLocals) {
240   assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
241          "minimum 1 column");
242   *this = FlatAffineConstraints(numReservedInequalities, numReservedEqualities,
243                                 newNumReservedCols, newNumDims, newNumSymbols,
244                                 newNumLocals);
245 }
246 
247 void FlatAffineConstraints::reset(unsigned newNumDims, unsigned newNumSymbols,
248                                   unsigned newNumLocals) {
249   reset(/*numReservedInequalities=*/0, /*numReservedEqualities=*/0,
250         /*numReservedCols=*/newNumDims + newNumSymbols + newNumLocals + 1,
251         newNumDims, newNumSymbols, newNumLocals);
252 }
253 
254 void FlatAffineValueConstraints::reset(unsigned numReservedInequalities,
255                                        unsigned numReservedEqualities,
256                                        unsigned newNumReservedCols,
257                                        unsigned newNumDims,
258                                        unsigned newNumSymbols,
259                                        unsigned newNumLocals) {
260   reset(numReservedInequalities, numReservedEqualities, newNumReservedCols,
261         newNumDims, newNumSymbols, newNumLocals, /*valArgs=*/{});
262 }
263 
264 void FlatAffineValueConstraints::reset(
265     unsigned numReservedInequalities, unsigned numReservedEqualities,
266     unsigned newNumReservedCols, unsigned newNumDims, unsigned newNumSymbols,
267     unsigned newNumLocals, ArrayRef<Value> valArgs) {
268   assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
269          "minimum 1 column");
270   SmallVector<Optional<Value>, 8> newVals;
271   if (!valArgs.empty())
272     newVals.assign(valArgs.begin(), valArgs.end());
273 
274   *this = FlatAffineValueConstraints(
275       numReservedInequalities, numReservedEqualities, newNumReservedCols,
276       newNumDims, newNumSymbols, newNumLocals, newVals);
277 }
278 
279 void FlatAffineValueConstraints::reset(unsigned newNumDims,
280                                        unsigned newNumSymbols,
281                                        unsigned newNumLocals,
282                                        ArrayRef<Value> valArgs) {
283   reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
284         newNumSymbols, newNumLocals, valArgs);
285 }
286 
287 unsigned FlatAffineValueConstraints::appendDimId(ValueRange vals) {
288   unsigned pos = getNumDimIds();
289   insertId(IdKind::SetDim, pos, vals);
290   return pos;
291 }
292 
293 unsigned FlatAffineValueConstraints::appendSymbolId(ValueRange vals) {
294   unsigned pos = getNumSymbolIds();
295   insertId(IdKind::Symbol, pos, vals);
296   return pos;
297 }
298 
299 unsigned FlatAffineValueConstraints::insertDimId(unsigned pos,
300                                                  ValueRange vals) {
301   return insertId(IdKind::SetDim, pos, vals);
302 }
303 
304 unsigned FlatAffineValueConstraints::insertSymbolId(unsigned pos,
305                                                     ValueRange vals) {
306   return insertId(IdKind::Symbol, pos, vals);
307 }
308 
309 unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos,
310                                               unsigned num) {
311   unsigned absolutePos = FlatAffineConstraints::insertId(kind, pos, num);
312   values.insert(values.begin() + absolutePos, num, None);
313   assert(values.size() == getNumIds());
314   return absolutePos;
315 }
316 
317 unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos,
318                                               ValueRange vals) {
319   assert(!vals.empty() && "expected ValueRange with Values");
320   unsigned num = vals.size();
321   unsigned absolutePos = FlatAffineConstraints::insertId(kind, pos, num);
322 
323   // If a Value is provided, insert it; otherwise use None.
324   for (unsigned i = 0; i < num; ++i)
325     values.insert(values.begin() + absolutePos + i,
326                   vals[i] ? Optional<Value>(vals[i]) : None);
327 
328   assert(values.size() == getNumIds());
329   return absolutePos;
330 }
331 
332 bool FlatAffineValueConstraints::hasValues() const {
333   return llvm::find_if(values, [](Optional<Value> id) {
334            return id.hasValue();
335          }) != values.end();
336 }
337 
338 /// Checks if two constraint systems are in the same space, i.e., if they are
339 /// associated with the same set of identifiers, appearing in the same order.
340 static bool areIdsAligned(const FlatAffineValueConstraints &a,
341                           const FlatAffineValueConstraints &b) {
342   return a.getNumDimIds() == b.getNumDimIds() &&
343          a.getNumSymbolIds() == b.getNumSymbolIds() &&
344          a.getNumIds() == b.getNumIds() &&
345          a.getMaybeValues().equals(b.getMaybeValues());
346 }
347 
348 /// Calls areIdsAligned to check if two constraint systems have the same set
349 /// of identifiers in the same order.
350 bool FlatAffineValueConstraints::areIdsAlignedWithOther(
351     const FlatAffineValueConstraints &other) {
352   return areIdsAligned(*this, other);
353 }
354 
355 /// Checks if the SSA values associated with `cst`'s identifiers in range
356 /// [start, end) are unique.
357 static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique(
358     const FlatAffineValueConstraints &cst, unsigned start, unsigned end) {
359 
360   assert(start <= cst.getNumIds() && "Start position out of bounds");
361   assert(end <= cst.getNumIds() && "End position out of bounds");
362 
363   if (start >= end)
364     return true;
365 
366   SmallPtrSet<Value, 8> uniqueIds;
367   ArrayRef<Optional<Value>> maybeValues = cst.getMaybeValues();
368   for (Optional<Value> val : maybeValues) {
369     if (val.hasValue() && !uniqueIds.insert(val.getValue()).second)
370       return false;
371   }
372   return true;
373 }
374 
375 /// Checks if the SSA values associated with `cst`'s identifiers are unique.
376 static bool LLVM_ATTRIBUTE_UNUSED
377 areIdsUnique(const FlatAffineConstraints &cst) {
378   return areIdsUnique(cst, 0, cst.getNumIds());
379 }
380 
381 /// Checks if the SSA values associated with `cst`'s identifiers of kind `kind`
382 /// are unique.
383 static bool LLVM_ATTRIBUTE_UNUSED
384 areIdsUnique(const FlatAffineValueConstraints &cst, IdKind kind) {
385 
386   if (kind == IdKind::SetDim)
387     return areIdsUnique(cst, 0, cst.getNumDimIds());
388   if (kind == IdKind::Symbol)
389     return areIdsUnique(cst, cst.getNumDimIds(), cst.getNumDimAndSymbolIds());
390   if (kind == IdKind::Local)
391     return areIdsUnique(cst, cst.getNumDimAndSymbolIds(), cst.getNumIds());
392   llvm_unreachable("Unexpected IdKind");
393 }
394 
/// Merge and align the identifiers of A and B starting at 'offset', so that
/// both constraint systems get the union of the contained identifiers that is
/// dimension-wise and symbol-wise unique; both constraint systems are updated
/// so that they have the union of all identifiers, with A's original
/// identifiers appearing first followed by any of B's identifiers that didn't
/// appear in A. Local identifiers in B that have the same division
/// representation as local identifiers in A are merged into one.
//  E.g.: Input: A has ((%i, %j) [%M, %N]) and B has ((%k, %j) [%P, %N, %M])
//        Output: both A, B have (%i, %j, %k) [%M, %N, %P]
static void mergeAndAlignIds(unsigned offset, FlatAffineValueConstraints *a,
                             FlatAffineValueConstraints *b) {
  assert(offset <= a->getNumDimIds() && offset <= b->getNumDimIds());
  // A merge/align isn't meaningful if a cst's ids aren't distinct.
  assert(areIdsUnique(*a) && "A's values aren't unique");
  assert(areIdsUnique(*b) && "B's values aren't unique");

  // Every dim/symbol from `offset` on must carry a Value, since matching
  // between A and B is done by Value.
  assert(std::all_of(a->getMaybeValues().begin() + offset,
                     a->getMaybeValues().begin() + a->getNumDimAndSymbolIds(),
                     [](Optional<Value> id) { return id.hasValue(); }));

  assert(std::all_of(b->getMaybeValues().begin() + offset,
                     b->getMaybeValues().begin() + b->getNumDimAndSymbolIds(),
                     [](Optional<Value> id) { return id.hasValue(); }));

  // A's dimension Values from `offset` onwards, in order.
  SmallVector<Value, 4> aDimValues;
  a->getValues(offset, a->getNumDimIds(), &aDimValues);

  {
    // Merge dims from A into B.
    // `d` is the position in B that should hold A's current dim: a matching
    // dim already in B is swapped there; otherwise a new dim is inserted there.
    unsigned d = offset;
    for (auto aDimValue : aDimValues) {
      unsigned loc;
      if (b->findId(aDimValue, &loc)) {
        assert(loc >= offset && "A's dim appears in B's aligned range");
        assert(loc < b->getNumDimIds() &&
               "A's dim appears in B's non-dim position");
        b->swapId(d, loc);
      } else {
        b->insertDimId(d, aDimValue);
      }
      d++;
    }
    // Dimensions that are in B, but not in A, are added at the end.
    // After the loop above, B's dims past A's dim count are exactly those.
    for (unsigned t = a->getNumDimIds(), e = b->getNumDimIds(); t < e; t++) {
      a->appendDimId(b->getValue(t));
    }
    assert(a->getNumDimIds() == b->getNumDimIds() &&
           "expected same number of dims");
  }

  // Merge and align symbols of A and B
  a->mergeSymbolIds(*b);
  // Merge and align local ids of A and B
  a->mergeLocalIds(*b);

  assert(areIdsAligned(*a, *b) && "IDs expected to be aligned");
}
452 
453 // Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
454 void FlatAffineValueConstraints::mergeAndAlignIdsWithOther(
455     unsigned offset, FlatAffineValueConstraints *other) {
456   mergeAndAlignIds(offset, this, other);
457 }
458 
459 LogicalResult
460 FlatAffineValueConstraints::composeMap(const AffineValueMap *vMap) {
461   return composeMatchingMap(
462       computeAlignedMap(vMap->getAffineMap(), vMap->getOperands()));
463 }
464 
// Similar to `composeMap` except that no Values need be associated with the
// constraint system nor are they looked at -- the dimensions and symbols of
// `other` are expected to correspond 1:1 to `this` system.
//
// On success, one new dimension per result of `other` is inserted at the front
// of the system, and each is tied to its result expression by an equality.
// Returns failure if `other` cannot be flattened.
LogicalResult FlatAffineConstraints::composeMatchingMap(AffineMap other) {
  assert(other.getNumDims() == getNumDimIds() && "dim mismatch");
  assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");

  std::vector<SmallVector<int64_t, 8>> flatExprs;
  if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
    return failure();
  assert(flatExprs.size() == other.getNumResults());

  // Add dimensions corresponding to the map's results.
  // From here on, the original dims/symbols are shifted right by
  // other.getNumResults() columns, and getNumCols()/getNumDimIds() include the
  // new dims.
  insertDimId(/*pos=*/0, /*num=*/other.getNumResults());

  // We add one equality for each result connecting the result dim of the map to
  // the other identifiers.
  // E.g.: if the expression is 16*i0 + i1, and this is the r^th
  // iteration/result of the value map, we are adding the equality:
  // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
  // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
  for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
    // Layout of flatExpr: [map inputs (dims, symbols) | locals | constant].
    const auto &flatExpr = flatExprs[r];
    assert(flatExpr.size() >= other.getNumInputs() + 1);

    SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
    // Set the coefficient for this result to one.
    eqToAdd[r] = 1;

    // Dims and symbols.
    // They sit after the `e` newly inserted result dims, hence the `e + i`.
    for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
      // Negate the expression's coefficients: the new result dim is set equal
      // to this expression (d_r - expr == 0).
      eqToAdd[e + i] = -flatExpr[i];
    }
    // Local columns of `flatExpr` (between its inputs and its constant) map
    // onto the first local columns of this system.
    // NOTE(review): assumes flattenAlignedMapAndMergeLocals arranges the
    // expression's locals to match this system's leading locals -- confirm.
    unsigned j = getNumDimIds() + getNumSymbolIds();
    unsigned end = flatExpr.size() - 1;
    for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
      eqToAdd[j] = -flatExpr[i];
    }

    // Constant term.
    eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];

    // Add the equality connecting the result of the map to this constraint set.
    addEquality(eqToAdd);
  }

  return success();
}
515 
516 // Turn a symbol into a dimension.
517 static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value id) {
518   unsigned pos;
519   if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() &&
520       pos < cst->getNumDimAndSymbolIds()) {
521     cst->swapId(pos, cst->getNumDimIds());
522     cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1);
523   }
524 }
525 
/// Merge and align symbols of `this` and `other` such that both get the union
/// of symbols that are unique. Symbols in `this` and `other` should be
/// unique. Symbols with Value as `None` are considered to be unequal to all
/// other symbols.
///
/// After the call, both systems list `this`'s original symbols first (in their
/// original order), followed by the symbols that only appeared in `other`.
void FlatAffineValueConstraints::mergeSymbolIds(
    FlatAffineValueConstraints &other) {

  assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique");
  assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique");

  SmallVector<Value, 4> aSymValues;
  getValues(getNumDimIds(), getNumDimAndSymbolIds(), &aSymValues);

  // Merge symbols: merge symbols into `other` first from `this`.
  // `s` is the absolute position in `other` that should receive the current
  // symbol of `this`.
  unsigned s = other.getNumDimIds();
  for (Value aSymValue : aSymValues) {
    unsigned loc;
    // If the id is a symbol in `other`, then align it, otherwise assume that
    // it is a new symbol
    if (other.findId(aSymValue, &loc) && loc >= other.getNumDimIds() &&
        loc < other.getNumDimAndSymbolIds())
      other.swapId(s, loc);
    else
      other.insertSymbolId(s - other.getNumDimIds(), aSymValue);
    s++;
  }

  // Symbols that are in other, but not in this, are added at the end.
  // The loop bounds are computed once, before any insertion into `this`.
  for (unsigned t = other.getNumDimIds() + getNumSymbolIds(),
                e = other.getNumDimAndSymbolIds();
       t < e; t++)
    insertSymbolId(getNumSymbolIds(), other.getValue(t));

  assert(getNumSymbolIds() == other.getNumSymbolIds() &&
         "expected same number of symbols");
  assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique");
  assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique");
}
564 
565 // Changes all symbol identifiers which are loop IVs to dim identifiers.
566 void FlatAffineValueConstraints::convertLoopIVSymbolsToDims() {
567   // Gather all symbols which are loop IVs.
568   SmallVector<Value, 4> loopIVs;
569   for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) {
570     if (hasValue(i) && getForInductionVarOwner(getValue(i)))
571       loopIVs.push_back(getValue(i));
572   }
573   // Turn each symbol in 'loopIVs' into a dim identifier.
574   for (auto iv : loopIVs) {
575     turnSymbolIntoDim(this, iv);
576   }
577 }
578 
579 void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) {
580   if (containsId(val))
581     return;
582 
583   // Caller is expected to fully compose map/operands if necessary.
584   assert((isTopLevelValue(val) || isForInductionVar(val)) &&
585          "non-terminal symbol / loop IV expected");
586   // Outer loop IVs could be used in forOp's bounds.
587   if (auto loop = getForInductionVarOwner(val)) {
588     appendDimId(val);
589     if (failed(this->addAffineForOpDomain(loop)))
590       LLVM_DEBUG(
591           loop.emitWarning("failed to add domain info to constraint system"));
592     return;
593   }
594   // Add top level symbol.
595   appendSymbolId(val);
596   // Check if the symbol is a constant.
597   if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>())
598     addBound(BoundType::EQ, val, constOp.value());
599 }
600 
/// Adds the iteration domain of `forOp` as constraints on its induction
/// variable, which must already be an identifier of this system: lower and
/// upper bounds, plus a stride constraint when the step is not 1 and the lower
/// bound is constant. Returns failure if the IV is not found or a non-constant
/// bound cannot be added.
LogicalResult
FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) {
  unsigned pos;
  // Pre-condition for this method.
  if (!findId(forOp.getInductionVar(), &pos)) {
    assert(false && "Value not found");
    return failure();
  }

  int64_t step = forOp.getStep();
  if (step != 1) {
    // The stride is only encoded for constant lower bounds; otherwise the
    // domain is over-approximated by its bounds alone.
    if (!forOp.hasConstantLowerBound())
      LLVM_DEBUG(forOp.emitWarning("domain conservatively approximated"));
    else {
      // Add constraints for the stride.
      // (iv - lb) % step = 0 can be written as:
      // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
      // Add local variable 'q' and add the above equality.
      // The first constraint is q = (iv - lb) floordiv step
      SmallVector<int64_t, 8> dividend(getNumCols(), 0);
      int64_t lb = forOp.getConstantLowerBound();
      dividend[pos] = 1;
      dividend.back() -= lb;
      // Appends `q` as a new local column, so getNumCols() grows by one.
      addLocalFloorDiv(dividend, step);
      // Second constraint: (iv - lb) - step * q = 0.
      SmallVector<int64_t, 8> eq(getNumCols(), 0);
      eq[pos] = 1;
      eq.back() -= lb;
      // For the local var just added above.
      // It is the last identifier, i.e. the column right before the constant.
      eq[getNumCols() - 2] = -step;
      addEquality(eq);
    }
  }

  if (forOp.hasConstantLowerBound()) {
    addBound(BoundType::LB, pos, forOp.getConstantLowerBound());
  } else {
    // Non-constant lower bound case.
    if (failed(addBound(BoundType::LB, pos, forOp.getLowerBoundMap(),
                        forOp.getLowerBoundOperands())))
      return failure();
  }

  if (forOp.hasConstantUpperBound()) {
    // The loop upper bound is exclusive, hence the `- 1`.
    addBound(BoundType::UB, pos, forOp.getConstantUpperBound() - 1);
    return success();
  }
  // Non-constant upper bound case.
  return addBound(BoundType::UB, pos, forOp.getUpperBoundMap(),
                  forOp.getUpperBoundOperands());
}
652 
/// Adds constraints for the first `lbMaps.size()` dimensions from slice
/// bounds: `lbMaps[i]` / `ubMaps[i]` bound dimension i (a null map means no
/// bound on that side). All maps are over `operands`. Returns failure when an
/// equality-style slice bound cannot be mapped back to an existing loop, or
/// when a bound cannot be added.
LogicalResult
FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
                                                   ArrayRef<AffineMap> ubMaps,
                                                   ArrayRef<Value> operands) {
  assert(lbMaps.size() == ubMaps.size());
  assert(lbMaps.size() <= getNumDimIds());

  for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
    AffineMap lbMap = lbMaps[i];
    AffineMap ubMap = ubMaps[i];
    assert(!lbMap || lbMap.getNumInputs() == operands.size());
    assert(!ubMap || ubMap.getNumInputs() == operands.size());

    // Check if this slice is just an equality along this dimension. If so,
    // retrieve the existing loop it equates to and add it to the system.
    if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
        ubMap.getNumResults() == 1 &&
        lbMap.getResult(0) + 1 == ubMap.getResult(0) &&
        // The condition above will be true for maps describing a single
        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
        // Make sure we skip those cases by checking that the lb result is not
        // just a constant.
        !lbMap.getResult(0).isa<AffineConstantExpr>()) {
      // Limited support: we expect the lb result to be just a loop dimension.
      // Not supported otherwise for now.
      AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
      if (!result)
        return failure();

      // A dim position indexes the map's dim operands, which come first in
      // `operands`.
      AffineForOp loop =
          getForInductionVarOwner(operands[result.getPosition()]);
      if (!loop)
        return failure();

      if (failed(addAffineForOpDomain(loop)))
        return failure();
      continue;
    }

    // This slice refers to a loop that doesn't exist in the IR yet. Add its
    // bounds to the system assuming its dimension identifier position is the
    // same as the position of the loop in the loop nest.
    if (lbMap && failed(addBound(BoundType::LB, i, lbMap, operands)))
      return failure();
    if (ubMap && failed(addBound(BoundType::UB, i, ubMap, operands)))
      return failure();
  }
  return success();
}
702 
703 void FlatAffineValueConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
704   // Create the base constraints from the integer set attached to ifOp.
705   FlatAffineValueConstraints cst(ifOp.getIntegerSet());
706 
707   // Bind ids in the constraints to ifOp operands.
708   SmallVector<Value, 4> operands = ifOp.getOperands();
709   cst.setValues(0, cst.getNumDimAndSymbolIds(), operands);
710 
711   // Merge the constraints from ifOp to the current domain. We need first merge
712   // and align the IDs from both constraints, and then append the constraints
713   // from the ifOp into the current one.
714   mergeAndAlignIdsWithOther(0, &cst);
715   append(cst);
716 }
717 
718 bool FlatAffineValueConstraints::hasConsistentState() const {
719   return FlatAffineConstraints::hasConsistentState() &&
720          values.size() == getNumIds();
721 }
722 
723 void FlatAffineValueConstraints::removeIdRange(IdKind kind, unsigned idStart,
724                                                unsigned idLimit) {
725   FlatAffineConstraints::removeIdRange(kind, idStart, idLimit);
726   unsigned offset = getIdKindOffset(kind);
727   values.erase(values.begin() + idStart + offset,
728                values.begin() + idLimit + offset);
729 }
730 
// Determine whether the identifier at 'pos' (say id_r) can be expressed as
// modulo of another known identifier (say id_n) w.r.t a constant. For example,
// if the following constraints hold true:
// ```
// 0 <= id_r <= divisor - 1
// id_n - (divisor * q_expr) = id_r
// ```
// where `id_n` is a known identifier (called dividend), and `q_expr` is an
// `AffineExpr` (called the quotient expression), `id_r` can be written as:
//
// `id_r = id_n mod divisor`.
//
// Additionally, in a special case of the above constraints where `q_expr` is
// an identifier itself that is not yet known (say `id_q`), it can be written
// as a floordiv in the following way:
//
// `id_q = id_n floordiv divisor`.
//
// `lbConst` and `ubConst` are the constant lower and upper bounds of `id_r`.
// The divisor is derived as `ubConst + 1`, which requires `lbConst == 0` and
// `ubConst >= 1`. `memo` has one entry per identifier of `cst`; a non-null
// entry is the already-known expression for that identifier. Only equalities
// with a zero constant term are considered, and the dividend must simplify to
// a single dimension expression (symbols are not handled yet, see TODO below).
//
// Returns true if the above mod or floordiv are detected, updating 'memo' with
// these new expressions. Returns false otherwise.
static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
                        int64_t lbConst, int64_t ubConst,
                        SmallVectorImpl<AffineExpr> &memo,
                        MLIRContext *context) {
  assert(pos < cst.getNumIds() && "invalid position");

  // Check if a divisor satisfying the condition `0 <= id_r <= divisor - 1` can
  // be determined.
  if (lbConst != 0 || ubConst < 1)
    return false;
  int64_t divisor = ubConst + 1;

  // Check for the aforementioned conditions in each equality.
  for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
       curEquality < numEqualities; curEquality++) {
    int64_t coefficientAtPos = cst.atEq(curEquality, pos);
    // If current equality does not involve `id_r`, continue to the next
    // equality.
    if (coefficientAtPos == 0)
      continue;

    // Constant term should be 0 in this equality.
    if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0)
      continue;

    // Traverse through the equality and construct the dividend expression
    // `dividendExpr`, to contain all the identifiers which are known and are
    // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the
    // `dividendExpr` gets simplified into a single identifier `id_n` discussed
    // above.
    auto dividendExpr = getAffineConstantExpr(0, context);

    // Track the terms that go into quotient expression, later used to detect
    // additional floordiv.
    unsigned quotientCount = 0;
    int quotientPosition = -1;
    int quotientSign = 1;

    // Consider each term in the current equality.
    // NOTE(review): only dimension and symbol columns are scanned here; the
    // coefficients of local identifiers in this equality are not examined.
    unsigned curId, e;
    for (curId = 0, e = cst.getNumDimAndSymbolIds(); curId < e; ++curId) {
      // Ignore id_r.
      if (curId == pos)
        continue;
      int64_t coefficientOfCurId = cst.atEq(curEquality, curId);
      // Ignore ids that do not contribute to the current equality.
      if (coefficientOfCurId == 0)
        continue;
      // Check if the current id goes into the quotient expression. Terms whose
      // coefficient is a multiple of `divisor * coefficientAtPos` are counted
      // as quotient terms and left out of the dividend.
      if (coefficientOfCurId % (divisor * coefficientAtPos) == 0) {
        quotientCount++;
        quotientPosition = curId;
        quotientSign = (coefficientOfCurId * coefficientAtPos) > 0 ? 1 : -1;
        continue;
      }
      // Identifiers that are part of dividendExpr should be known.
      if (!memo[curId])
        break;
      // Append the current identifier to the dividend expression.
      dividendExpr = dividendExpr + memo[curId] * coefficientOfCurId;
    }

    // Can't construct expression as it depends on a yet uncomputed id.
    if (curId < e)
      continue;

    // Express `id_r` in terms of the other ids collected so far: divide the
    // sum of the non-quotient terms (negated when coefficientAtPos > 0) by
    // |coefficientAtPos|.
    if (coefficientAtPos > 0)
      dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
    else
      dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);

    // Simplify the expression.
    dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimIds(),
                                      cst.getNumSymbolIds());
    // Only if the final dividend expression is just a single id (which we call
    // `id_n`), we can proceed.
    // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
    // to dims themselves.
    auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
    if (!dimExpr)
      continue;

    // Express `id_r` as `id_n % divisor` and store the expression in `memo`.
    if (quotientCount >= 1) {
      auto ub = cst.getConstantBound(FlatAffineConstraints::BoundType::UB,
                                     dimExpr.getPosition());
      // If `id_n` has an upperbound that is less than the divisor, mod can be
      // eliminated altogether.
      if (ub.hasValue() && ub.getValue() < divisor)
        memo[pos] = dimExpr;
      else
        memo[pos] = dimExpr % divisor;
      // If a unique quotient `id_q` was seen, it can be expressed as
      // `id_n floordiv divisor`.
      if (quotientCount == 1 && !memo[quotientPosition])
        memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign;

      return true;
    }
  }
  return false;
}
854 
855 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
856 /// function of other identifiers (where the divisor is a positive constant)
857 /// given the initial set of expressions in `exprs`. If it can be, the
858 /// corresponding position in `exprs` is set as the detected affine expr. For
859 /// eg: 4q <= i + j <= 4q + 3   <=>   q = (i + j) floordiv 4. An equality can
860 /// also yield a floordiv: eg.  4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
861 /// <= i <= 32q + 31 => q = i floordiv 32.
862 static bool detectAsFloorDiv(const FlatAffineConstraints &cst, unsigned pos,
863                              MLIRContext *context,
864                              SmallVectorImpl<AffineExpr> &exprs) {
865   assert(pos < cst.getNumIds() && "invalid position");
866 
867   // Get upper-lower bound pair for this variable.
868   SmallVector<bool, 8> foundRepr(cst.getNumIds(), false);
869   for (unsigned i = 0, e = cst.getNumIds(); i < e; ++i)
870     if (exprs[i])
871       foundRepr[i] = true;
872 
873   SmallVector<int64_t, 8> dividend;
874   unsigned divisor;
875   auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor);
876 
877   // No upper-lower bound pair found for this var.
878   if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality)
879     return false;
880 
881   // Construct the dividend expression.
882   auto dividendExpr = getAffineConstantExpr(dividend.back(), context);
883   for (unsigned c = 0, f = cst.getNumIds(); c < f; c++)
884     if (dividend[c] != 0)
885       dividendExpr = dividendExpr + dividend[c] * exprs[c];
886 
887   // Successfully detected the floordiv.
888   exprs[pos] = dividendExpr.floorDiv(divisor);
889   return true;
890 }
891 
/// Returns the lower and upper bound maps of the identifier at `pos + offset`,
/// expressed in terms of the identifiers outside the range
/// [offset, offset + num). After that range is dropped, the first
/// `symStartPos - num` remaining columns become map dims and the rest of the
/// dimension/symbol columns become map symbols. `localExprs` supplies one
/// explicit expression per local identifier.
///
/// The lower bound map has one result per lower-bounding inequality, rounded
/// up with a ceildiv. The upper bound map has one result per upper-bounding
/// inequality, rounded down and made exclusive by adding 1. Each equality
/// involving the identifier contributes a result to both maps. The constraint
/// indices are obtained from getLowerAndUpperBoundIndices.
std::pair<AffineMap, AffineMap> FlatAffineConstraints::getLowerAndUpperBound(
    unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
    ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
  assert(pos + offset < getNumDimIds() && "invalid dim start pos");
  assert(symStartPos >= (pos + offset) && "invalid sym start pos");
  assert(getNumLocalIds() == localExprs.size() &&
         "incorrect local exprs count");

  SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
  getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
                               offset, num);

  /// Copy into 'b' every coefficient of 'a' whose column lies outside
  /// [offset, offset + num), i.e. drop the columns of the identifiers being
  /// bounded. The constant term and local columns are kept.
  auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
    b.clear();
    for (unsigned i = 0, e = a.size(); i < e; ++i) {
      if (i < offset || i >= offset + num)
        b.push_back(a[i]);
    }
  };

  SmallVector<int64_t, 8> lb, ub;
  SmallVector<AffineExpr, 4> lbExprs;
  unsigned dimCount = symStartPos - num;
  unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
  lbExprs.reserve(lbIndices.size() + eqIndices.size());
  // Lower bound expressions.
  for (auto idx : lbIndices) {
    auto ineq = getInequality(idx);
    // Extract the lower bound (in terms of other coeff's + const), i.e., if
    // i - j + 1 >= 0 is the constraint and 'pos' is for i, the lower bound is
    // j - 1.
    addCoeffs(ineq, lb);
    std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
    auto expr =
        getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
    // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
    int64_t divisor = std::abs(ineq[pos + offset]);
    expr = (expr + divisor - 1).floorDiv(divisor);
    lbExprs.push_back(expr);
  }

  SmallVector<AffineExpr, 4> ubExprs;
  ubExprs.reserve(ubIndices.size() + eqIndices.size());
  // Upper bound expressions.
  for (auto idx : ubIndices) {
    auto ineq = getInequality(idx);
    // Extract the upper bound (in terms of other coeff's + const).
    addCoeffs(ineq, ub);
    auto expr =
        getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
    expr = expr.floorDiv(std::abs(ineq[pos + offset]));
    // Upper bound is exclusive.
    ubExprs.push_back(expr + 1);
  }

  // Equalities. Each is both a lower and an upper bound. The remaining
  // coefficients are negated when the identifier's coefficient is positive so
  // that 'b' holds the numerator of the identifier's solved form.
  SmallVector<int64_t, 4> b;
  for (auto idx : eqIndices) {
    auto eq = getEquality(idx);
    addCoeffs(eq, b);
    if (eq[pos + offset] > 0)
      std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());

    // Extract the upper bound (in terms of other coeff's + const).
    auto expr =
        getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
    expr = expr.floorDiv(std::abs(eq[pos + offset]));
    // Upper bound is exclusive.
    ubExprs.push_back(expr + 1);
    // Lower bound.
    expr =
        getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
    expr = expr.ceilDiv(std::abs(eq[pos + offset]));
    lbExprs.push_back(expr);
  }

  auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
  auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);

  return {lbMap, ubMap};
}
974 
975 /// Computes the lower and upper bounds of the first 'num' dimensional
976 /// identifiers (starting at 'offset') as affine maps of the remaining
977 /// identifiers (dimensional and symbolic identifiers). Local identifiers are
978 /// themselves explicitly computed as affine functions of other identifiers in
979 /// this process if needed.
980 void FlatAffineConstraints::getSliceBounds(unsigned offset, unsigned num,
981                                            MLIRContext *context,
982                                            SmallVectorImpl<AffineMap> *lbMaps,
983                                            SmallVectorImpl<AffineMap> *ubMaps) {
984   assert(num < getNumDimIds() && "invalid range");
985 
986   // Basic simplification.
987   normalizeConstraintsByGCD();
988 
989   LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
990                           << " identifiers\n");
991   LLVM_DEBUG(dump());
992 
993   // Record computed/detected identifiers.
994   SmallVector<AffineExpr, 8> memo(getNumIds());
995   // Initialize dimensional and symbolic identifiers.
996   for (unsigned i = 0, e = getNumDimIds(); i < e; i++) {
997     if (i < offset)
998       memo[i] = getAffineDimExpr(i, context);
999     else if (i >= offset + num)
1000       memo[i] = getAffineDimExpr(i - num, context);
1001   }
1002   for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++)
1003     memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context);
1004 
1005   bool changed;
1006   do {
1007     changed = false;
1008     // Identify yet unknown identifiers as constants or mod's / floordiv's of
1009     // other identifiers if possible.
1010     for (unsigned pos = 0; pos < getNumIds(); pos++) {
1011       if (memo[pos])
1012         continue;
1013 
1014       auto lbConst = getConstantBound(BoundType::LB, pos);
1015       auto ubConst = getConstantBound(BoundType::UB, pos);
1016       if (lbConst.hasValue() && ubConst.hasValue()) {
1017         // Detect equality to a constant.
1018         if (lbConst.getValue() == ubConst.getValue()) {
1019           memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
1020           changed = true;
1021           continue;
1022         }
1023 
1024         // Detect an identifier as modulo of another identifier w.r.t a
1025         // constant.
1026         if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
1027                         memo, context)) {
1028           changed = true;
1029           continue;
1030         }
1031       }
1032 
1033       // Detect an identifier as a floordiv of an affine function of other
1034       // identifiers (divisor is a positive constant).
1035       if (detectAsFloorDiv(*this, pos, context, memo)) {
1036         changed = true;
1037         continue;
1038       }
1039 
1040       // Detect an identifier as an expression of other identifiers.
1041       unsigned idx;
1042       if (!findConstraintWithNonZeroAt(pos, /*isEq=*/true, &idx)) {
1043         continue;
1044       }
1045 
1046       // Build AffineExpr solving for identifier 'pos' in terms of all others.
1047       auto expr = getAffineConstantExpr(0, context);
1048       unsigned j, e;
1049       for (j = 0, e = getNumIds(); j < e; ++j) {
1050         if (j == pos)
1051           continue;
1052         int64_t c = atEq(idx, j);
1053         if (c == 0)
1054           continue;
1055         // If any of the involved IDs hasn't been found yet, we can't proceed.
1056         if (!memo[j])
1057           break;
1058         expr = expr + memo[j] * c;
1059       }
1060       if (j < e)
1061         // Can't construct expression as it depends on a yet uncomputed
1062         // identifier.
1063         continue;
1064 
1065       // Add constant term to AffineExpr.
1066       expr = expr + atEq(idx, getNumIds());
1067       int64_t vPos = atEq(idx, pos);
1068       assert(vPos != 0 && "expected non-zero here");
1069       if (vPos > 0)
1070         expr = (-expr).floorDiv(vPos);
1071       else
1072         // vPos < 0.
1073         expr = expr.floorDiv(-vPos);
1074       // Successfully constructed expression.
1075       memo[pos] = expr;
1076       changed = true;
1077     }
1078     // This loop is guaranteed to reach a fixed point - since once an
1079     // identifier's explicit form is computed (in memo[pos]), it's not updated
1080     // again.
1081   } while (changed);
1082 
1083   // Set the lower and upper bound maps for all the identifiers that were
1084   // computed as affine expressions of the rest as the "detected expr" and
1085   // "detected expr + 1" respectively; set the undetected ones to null.
1086   Optional<FlatAffineConstraints> tmpClone;
1087   for (unsigned pos = 0; pos < num; pos++) {
1088     unsigned numMapDims = getNumDimIds() - num;
1089     unsigned numMapSymbols = getNumSymbolIds();
1090     AffineExpr expr = memo[pos + offset];
1091     if (expr)
1092       expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);
1093 
1094     AffineMap &lbMap = (*lbMaps)[pos];
1095     AffineMap &ubMap = (*ubMaps)[pos];
1096 
1097     if (expr) {
1098       lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
1099       ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1);
1100     } else {
1101       // TODO: Whenever there are local identifiers in the dependence
1102       // constraints, we'll conservatively over-approximate, since we don't
1103       // always explicitly compute them above (in the while loop).
1104       if (getNumLocalIds() == 0) {
1105         // Work on a copy so that we don't update this constraint system.
1106         if (!tmpClone) {
1107           tmpClone.emplace(FlatAffineConstraints(*this));
1108           // Removing redundant inequalities is necessary so that we don't get
1109           // redundant loop bounds.
1110           tmpClone->removeRedundantInequalities();
1111         }
1112         std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
1113             pos, offset, num, getNumDimIds(), /*localExprs=*/{}, context);
1114       }
1115 
1116       // If the above fails, we'll just use the constant lower bound and the
1117       // constant upper bound (if they exist) as the slice bounds.
1118       // TODO: being conservative for the moment in cases that
1119       // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
1120       // fixed (b/126426796).
1121       if (!lbMap || lbMap.getNumResults() > 1) {
1122         LLVM_DEBUG(llvm::dbgs()
1123                    << "WARNING: Potentially over-approximating slice lb\n");
1124         auto lbConst = getConstantBound(BoundType::LB, pos + offset);
1125         if (lbConst.hasValue()) {
1126           lbMap = AffineMap::get(
1127               numMapDims, numMapSymbols,
1128               getAffineConstantExpr(lbConst.getValue(), context));
1129         }
1130       }
1131       if (!ubMap || ubMap.getNumResults() > 1) {
1132         LLVM_DEBUG(llvm::dbgs()
1133                    << "WARNING: Potentially over-approximating slice ub\n");
1134         auto ubConst = getConstantBound(BoundType::UB, pos + offset);
1135         if (ubConst.hasValue()) {
1136           (ubMap) = AffineMap::get(
1137               numMapDims, numMapSymbols,
1138               getAffineConstantExpr(ubConst.getValue() + 1, context));
1139         }
1140       }
1141     }
1142     LLVM_DEBUG(llvm::dbgs()
1143                << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
1144     LLVM_DEBUG(lbMap.dump(););
1145     LLVM_DEBUG(llvm::dbgs()
1146                << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
1147     LLVM_DEBUG(ubMap.dump(););
1148   }
1149 }
1150 
1151 LogicalResult FlatAffineConstraints::flattenAlignedMapAndMergeLocals(
1152     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) {
1153   FlatAffineConstraints localCst;
1154   if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) {
1155     LLVM_DEBUG(llvm::dbgs()
1156                << "composition unimplemented for semi-affine maps\n");
1157     return failure();
1158   }
1159 
1160   // Add localCst information.
1161   if (localCst.getNumLocalIds() > 0) {
1162     unsigned numLocalIds = getNumLocalIds();
1163     // Insert local dims of localCst at the beginning.
1164     insertLocalId(/*pos=*/0, /*num=*/localCst.getNumLocalIds());
1165     // Insert local dims of `this` at the end of localCst.
1166     localCst.appendLocalId(/*num=*/numLocalIds);
1167     // Dimensions of localCst and this constraint set match. Append localCst to
1168     // this constraint set.
1169     append(localCst);
1170   }
1171 
1172   return success();
1173 }
1174 
/// Adds a bound on the identifier at `pos` for each result of `boundMap`,
/// whose dims and symbols must line up with this system's dimension and
/// symbol identifiers. For each result r:
///   - BoundType::LB adds `id >= r`;
///   - BoundType::UB adds `id <= r - 1`, since upper bound maps are exclusive;
///   - BoundType::EQ, which requires a single-result map, adds `id == r`.
/// Local identifiers introduced while flattening `boundMap` are merged into
/// this system first. Results in which the identifier at `pos` itself appears
/// are skipped. Returns failure if `boundMap` cannot be flattened (e.g. a
/// semi-affine map).
LogicalResult FlatAffineConstraints::addBound(BoundType type, unsigned pos,
                                              AffineMap boundMap) {
  assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch");
  assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
  assert(pos < getNumDimAndSymbolIds() && "invalid position");

  // Equality follows the logic of lower bound except that we add an equality
  // instead of an inequality.
  assert((type != BoundType::EQ || boundMap.getNumResults() == 1) &&
         "single result expected");
  bool lower = type == BoundType::LB || type == BoundType::EQ;

  // Flattening may add local ids to `this` (placed before any existing ones).
  std::vector<SmallVector<int64_t, 8>> flatExprs;
  if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs)))
    return failure();
  assert(flatExprs.size() == boundMap.getNumResults());

  // Add one (in)equality for each result.
  for (const auto &flatExpr : flatExprs) {
    SmallVector<int64_t> ineq(getNumCols(), 0);
    // Dims and symbols. A lower bound `id >= expr` is stored as
    // `id - expr >= 0`, so the expression's coefficients are negated.
    for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
      ineq[j] = lower ? -flatExpr[j] : flatExpr[j];
    }
    // Invalid bound: pos appears in `boundMap`.
    // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or
    // its callers to prevent invalid bounds from being added.
    if (ineq[pos] != 0)
      continue;
    ineq[pos] = lower ? 1 : -1;
    // Local columns of `ineq` are at the beginning.
    unsigned j = getNumDimIds() + getNumSymbolIds();
    unsigned end = flatExpr.size() - 1;
    for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
      ineq[j] = lower ? -flatExpr[i] : flatExpr[i];
    }
    // Constant term.
    ineq[getNumCols() - 1] =
        lower ? -flatExpr[flatExpr.size() - 1]
              // Upper bound in flattenedExpr is an exclusive one.
              : flatExpr[flatExpr.size() - 1] - 1;
    type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq);
  }

  return success();
}
1221 
1222 AffineMap
1223 FlatAffineValueConstraints::computeAlignedMap(AffineMap map,
1224                                               ValueRange operands) const {
1225   assert(map.getNumInputs() == operands.size() && "number of inputs mismatch");
1226 
1227   SmallVector<Value> dims, syms;
1228 #ifndef NDEBUG
1229   SmallVector<Value> newSyms;
1230   SmallVector<Value> *newSymsPtr = &newSyms;
1231 #else
1232   SmallVector<Value> *newSymsPtr = nullptr;
1233 #endif // NDEBUG
1234 
1235   dims.reserve(getNumDimIds());
1236   syms.reserve(getNumSymbolIds());
1237   for (unsigned i = getIdKindOffset(IdKind::SetDim),
1238                 e = getIdKindEnd(IdKind::SetDim);
1239        i < e; ++i)
1240     dims.push_back(values[i] ? *values[i] : Value());
1241   for (unsigned i = getIdKindOffset(IdKind::Symbol),
1242                 e = getIdKindEnd(IdKind::Symbol);
1243        i < e; ++i)
1244     syms.push_back(values[i] ? *values[i] : Value());
1245 
1246   AffineMap alignedMap =
1247       alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
1248   // All symbols are already part of this FlatAffineConstraints.
1249   assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
1250   assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
1251          "unexpected new/missing symbols");
1252   return alignedMap;
1253 }
1254 
1255 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
1256                                                    AffineMap boundMap,
1257                                                    ValueRange boundOperands) {
1258   // Fully compose map and operands; canonicalize and simplify so that we
1259   // transitively get to terminal symbols or loop IVs.
1260   auto map = boundMap;
1261   SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
1262   fullyComposeAffineMapAndOperands(&map, &operands);
1263   map = simplifyAffineMap(map);
1264   canonicalizeMapAndOperands(&map, &operands);
1265   for (auto operand : operands)
1266     addInductionVarOrTerminalSymbol(operand);
1267   return addBound(type, pos, computeAlignedMap(map, operands));
1268 }
1269 
1270 // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
1271 // bounds in 'ubMaps' to each value in `values' that appears in the constraint
1272 // system. Note that both lower/upper bounds share the same operand list
1273 // 'operands'.
1274 // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and
1275 // skips any null AffineMaps in 'lbMaps' or 'ubMaps'.
1276 // Note that both lower/upper bounds use operands from 'operands'.
1277 // Returns failure for unimplemented cases such as semi-affine expressions or
1278 // expressions with mod/floordiv.
1279 LogicalResult FlatAffineValueConstraints::addSliceBounds(
1280     ArrayRef<Value> values, ArrayRef<AffineMap> lbMaps,
1281     ArrayRef<AffineMap> ubMaps, ArrayRef<Value> operands) {
1282   assert(values.size() == lbMaps.size());
1283   assert(lbMaps.size() == ubMaps.size());
1284 
1285   for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
1286     unsigned pos;
1287     if (!findId(values[i], &pos))
1288       continue;
1289 
1290     AffineMap lbMap = lbMaps[i];
1291     AffineMap ubMap = ubMaps[i];
1292     assert(!lbMap || lbMap.getNumInputs() == operands.size());
1293     assert(!ubMap || ubMap.getNumInputs() == operands.size());
1294 
1295     // Check if this slice is just an equality along this dimension.
1296     if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
1297         ubMap.getNumResults() == 1 &&
1298         lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
1299       if (failed(addBound(BoundType::EQ, pos, lbMap, operands)))
1300         return failure();
1301       continue;
1302     }
1303 
1304     // If lower or upper bound maps are null or provide no results, it implies
1305     // that the source loop was not at all sliced, and the entire loop will be a
1306     // part of the slice.
1307     if (lbMap && lbMap.getNumResults() != 0 && ubMap &&
1308         ubMap.getNumResults() != 0) {
1309       if (failed(addBound(BoundType::LB, pos, lbMap, operands)))
1310         return failure();
1311       if (failed(addBound(BoundType::UB, pos, ubMap, operands)))
1312         return failure();
1313     } else {
1314       auto loop = getForInductionVarOwner(values[i]);
1315       if (failed(this->addAffineForOpDomain(loop)))
1316         return failure();
1317     }
1318   }
1319   return success();
1320 }
1321 
1322 bool FlatAffineValueConstraints::findId(Value val, unsigned *pos) const {
1323   unsigned i = 0;
1324   for (const auto &mayBeId : values) {
1325     if (mayBeId.hasValue() && mayBeId.getValue() == val) {
1326       *pos = i;
1327       return true;
1328     }
1329     i++;
1330   }
1331   return false;
1332 }
1333 
1334 bool FlatAffineValueConstraints::containsId(Value val) const {
1335   return llvm::any_of(values, [&](const Optional<Value> &mayBeId) {
1336     return mayBeId.hasValue() && mayBeId.getValue() == val;
1337   });
1338 }
1339 
1340 void FlatAffineValueConstraints::swapId(unsigned posA, unsigned posB) {
1341   FlatAffineConstraints::swapId(posA, posB);
1342   std::swap(values[posA], values[posB]);
1343 }
1344 
1345 void FlatAffineValueConstraints::addBound(BoundType type, Value val,
1346                                           int64_t value) {
1347   unsigned pos;
1348   if (!findId(val, &pos))
1349     // This is a pre-condition for this method.
1350     assert(0 && "id not found");
1351   addBound(type, pos, value);
1352 }
1353 
1354 void FlatAffineConstraints::printSpace(raw_ostream &os) const {
1355   IntegerPolyhedron::printSpace(os);
1356   os << "(";
1357   for (unsigned i = 0, e = getNumIds(); i < e; i++) {
1358     if (auto *valueCstr = dyn_cast<const FlatAffineValueConstraints>(this)) {
1359       if (valueCstr->hasValue(i))
1360         os << "Value ";
1361       else
1362         os << "None ";
1363     } else {
1364       os << "None ";
1365     }
1366   }
1367   os << " const)\n";
1368 }
1369 
1370 void FlatAffineConstraints::clearAndCopyFrom(const IntegerRelation &other) {
1371   if (auto *otherValueSet = dyn_cast<const FlatAffineValueConstraints>(&other))
1372     assert(!otherValueSet->hasValues() &&
1373            "cannot copy associated Values into FlatAffineConstraints");
1374 
1375   // Note: Assigment operator does not vtable pointer, so kind does not
1376   // change.
1377   if (auto *otherValueSet = dyn_cast<const FlatAffineConstraints>(&other))
1378     *this = *otherValueSet;
1379   else
1380     *static_cast<IntegerRelation *>(this) = other;
1381 }
1382 
1383 void FlatAffineValueConstraints::clearAndCopyFrom(
1384     const IntegerRelation &other) {
1385 
1386   if (auto *otherValueSet =
1387           dyn_cast<const FlatAffineValueConstraints>(&other)) {
1388     *this = *otherValueSet;
1389     return;
1390   }
1391 
1392   if (auto *otherValueSet = dyn_cast<const FlatAffineValueConstraints>(&other))
1393     *static_cast<FlatAffineConstraints *>(this) = *otherValueSet;
1394   else
1395     *static_cast<IntegerRelation *>(this) = other;
1396 
1397   values.clear();
1398   values.resize(getNumIds(), None);
1399 }
1400 
1401 void FlatAffineValueConstraints::fourierMotzkinEliminate(
1402     unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
1403   SmallVector<Optional<Value>, 8> newVals;
1404   newVals.reserve(getNumIds() - 1);
1405   newVals.append(values.begin(), values.begin() + pos);
1406   newVals.append(values.begin() + pos + 1, values.end());
1407   // Note: Base implementation discards all associated Values.
1408   FlatAffineConstraints::fourierMotzkinEliminate(pos, darkShadow,
1409                                                  isResultIntegerExact);
1410   values = newVals;
1411   assert(values.size() == getNumIds());
1412 }
1413 
1414 void FlatAffineValueConstraints::projectOut(Value val) {
1415   unsigned pos;
1416   bool ret = findId(val, &pos);
1417   assert(ret);
1418   (void)ret;
1419   fourierMotzkinEliminate(pos);
1420 }
1421 
1422 LogicalResult FlatAffineValueConstraints::unionBoundingBox(
1423     const FlatAffineValueConstraints &otherCst) {
1424   assert(otherCst.getNumDimIds() == getNumDimIds() && "dims mismatch");
1425   assert(otherCst.getMaybeValues()
1426              .slice(0, getNumDimIds())
1427              .equals(getMaybeValues().slice(0, getNumDimIds())) &&
1428          "dim values mismatch");
1429   assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
1430   assert(getNumLocalIds() == 0 && "local ids not supported yet here");
1431 
1432   // Align `other` to this.
1433   if (!areIdsAligned(*this, otherCst)) {
1434     FlatAffineValueConstraints otherCopy(otherCst);
1435     mergeAndAlignIds(/*offset=*/getNumDimIds(), this, &otherCopy);
1436     return FlatAffineConstraints::unionBoundingBox(otherCopy);
1437   }
1438 
1439   return FlatAffineConstraints::unionBoundingBox(otherCst);
1440 }
1441 
1442 /// Compute an explicit representation for local vars. For all systems coming
1443 /// from MLIR integer sets, maps, or expressions where local vars were
1444 /// introduced to model floordivs and mods, this always succeeds.
1445 static LogicalResult computeLocalVars(const FlatAffineConstraints &cst,
1446                                       SmallVectorImpl<AffineExpr> &memo,
1447                                       MLIRContext *context) {
1448   unsigned numDims = cst.getNumDimIds();
1449   unsigned numSyms = cst.getNumSymbolIds();
1450 
1451   // Initialize dimensional and symbolic identifiers.
1452   for (unsigned i = 0; i < numDims; i++)
1453     memo[i] = getAffineDimExpr(i, context);
1454   for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
1455     memo[i] = getAffineSymbolExpr(i - numDims, context);
1456 
1457   bool changed;
1458   do {
1459     // Each time `changed` is true at the end of this iteration, one or more
1460     // local vars would have been detected as floordivs and set in memo; so the
1461     // number of null entries in memo[...] strictly reduces; so this converges.
1462     changed = false;
1463     for (unsigned i = 0, e = cst.getNumLocalIds(); i < e; ++i)
1464       if (!memo[numDims + numSyms + i] &&
1465           detectAsFloorDiv(cst, /*pos=*/numDims + numSyms + i, context, memo))
1466         changed = true;
1467   } while (changed);
1468 
1469   ArrayRef<AffineExpr> localExprs =
1470       ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalIds());
1471   return success(
1472       llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
1473 }
1474 
1475 void FlatAffineValueConstraints::getIneqAsAffineValueMap(
1476     unsigned pos, unsigned ineqPos, AffineValueMap &vmap,
1477     MLIRContext *context) const {
1478   unsigned numDims = getNumDimIds();
1479   unsigned numSyms = getNumSymbolIds();
1480 
1481   assert(pos < numDims && "invalid position");
1482   assert(ineqPos < getNumInequalities() && "invalid inequality position");
1483 
1484   // Get expressions for local vars.
1485   SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
1486   if (failed(computeLocalVars(*this, memo, context)))
1487     assert(false &&
1488            "one or more local exprs do not have an explicit representation");
1489   auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
1490 
1491   // Compute the AffineExpr lower/upper bound for this inequality.
1492   ArrayRef<int64_t> inequality = getInequality(ineqPos);
1493   SmallVector<int64_t, 8> bound;
1494   bound.reserve(getNumCols() - 1);
1495   // Everything other than the coefficient at `pos`.
1496   bound.append(inequality.begin(), inequality.begin() + pos);
1497   bound.append(inequality.begin() + pos + 1, inequality.end());
1498 
1499   if (inequality[pos] > 0)
1500     // Lower bound.
1501     std::transform(bound.begin(), bound.end(), bound.begin(),
1502                    std::negate<int64_t>());
1503   else
1504     // Upper bound (which is exclusive).
1505     bound.back() += 1;
1506 
1507   // Convert to AffineExpr (tree) form.
1508   auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms,
1509                                              localExprs, context);
1510 
1511   // Get the values to bind to this affine expr (all dims and symbols).
1512   SmallVector<Value, 4> operands;
1513   getValues(0, pos, &operands);
1514   SmallVector<Value, 4> trailingOperands;
1515   getValues(pos + 1, getNumDimAndSymbolIds(), &trailingOperands);
1516   operands.append(trailingOperands.begin(), trailingOperands.end());
1517   vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands);
1518 }
1519 
1520 IntegerSet FlatAffineConstraints::getAsIntegerSet(MLIRContext *context) const {
1521   if (getNumConstraints() == 0)
1522     // Return universal set (always true): 0 == 0.
1523     return IntegerSet::get(getNumDimIds(), getNumSymbolIds(),
1524                            getAffineConstantExpr(/*constant=*/0, context),
1525                            /*eqFlags=*/true);
1526 
1527   // Construct local references.
1528   SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
1529 
1530   if (failed(computeLocalVars(*this, memo, context))) {
1531     // Check if the local variables without an explicit representation have
1532     // zero coefficients everywhere.
1533     SmallVector<unsigned> noLocalRepVars;
1534     unsigned numDimsSymbols = getNumDimAndSymbolIds();
1535     for (unsigned i = numDimsSymbols, e = getNumIds(); i < e; ++i) {
1536       if (!memo[i] && !isColZero(/*pos=*/i))
1537         noLocalRepVars.push_back(i - numDimsSymbols);
1538     }
1539     if (!noLocalRepVars.empty()) {
1540       LLVM_DEBUG({
1541         llvm::dbgs() << "local variables at position(s) ";
1542         llvm::interleaveComma(noLocalRepVars, llvm::dbgs());
1543         llvm::dbgs() << " do not have an explicit representation in:\n";
1544         this->dump();
1545       });
1546       return IntegerSet();
1547     }
1548   }
1549 
1550   ArrayRef<AffineExpr> localExprs =
1551       ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
1552 
1553   // Construct the IntegerSet from the equalities/inequalities.
1554   unsigned numDims = getNumDimIds();
1555   unsigned numSyms = getNumSymbolIds();
1556 
1557   SmallVector<bool, 16> eqFlags(getNumConstraints());
1558   std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
1559   std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false);
1560 
1561   SmallVector<AffineExpr, 8> exprs;
1562   exprs.reserve(getNumConstraints());
1563 
1564   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
1565     exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms,
1566                                               localExprs, context));
1567   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
1568     exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims,
1569                                               numSyms, localExprs, context));
1570   return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
1571 }
1572 
1573 AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
1574                                          ValueRange dims, ValueRange syms,
1575                                          SmallVector<Value> *newSyms) {
1576   assert(operands.size() == map.getNumInputs() &&
1577          "expected same number of operands and map inputs");
1578   MLIRContext *ctx = map.getContext();
1579   Builder builder(ctx);
1580   SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {});
1581   unsigned numSymbols = syms.size();
1582   SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {});
1583   if (newSyms) {
1584     newSyms->clear();
1585     newSyms->append(syms.begin(), syms.end());
1586   }
1587 
1588   for (const auto &operand : llvm::enumerate(operands)) {
1589     // Compute replacement dim/sym of operand.
1590     AffineExpr replacement;
1591     auto dimIt = std::find(dims.begin(), dims.end(), operand.value());
1592     auto symIt = std::find(syms.begin(), syms.end(), operand.value());
1593     if (dimIt != dims.end()) {
1594       replacement =
1595           builder.getAffineDimExpr(std::distance(dims.begin(), dimIt));
1596     } else if (symIt != syms.end()) {
1597       replacement =
1598           builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt));
1599     } else {
1600       // This operand is neither a dimension nor a symbol. Add it as a new
1601       // symbol.
1602       replacement = builder.getAffineSymbolExpr(numSymbols++);
1603       if (newSyms)
1604         newSyms->push_back(operand.value());
1605     }
1606     // Add to corresponding replacements vector.
1607     if (operand.index() < map.getNumDims()) {
1608       dimReplacements[operand.index()] = replacement;
1609     } else {
1610       symReplacements[operand.index() - map.getNumDims()] = replacement;
1611     }
1612   }
1613 
1614   return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
1615                                    dims.size(), numSymbols);
1616 }
1617 
1618 FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const {
1619   FlatAffineValueConstraints domain = *this;
1620   // Convert all range variables to local variables.
1621   domain.convertDimToLocal(getNumDomainDims(),
1622                            getNumDomainDims() + getNumRangeDims());
1623   return domain;
1624 }
1625 
1626 FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const {
1627   FlatAffineValueConstraints range = *this;
1628   // Convert all domain variables to local variables.
1629   range.convertDimToLocal(0, getNumDomainDims());
1630   return range;
1631 }
1632 
/// Compose `other` into this relation, in place.
///
/// Given
///   other: [otherDomain] -> [otherRange]
///   this:  [thisDomain]  -> [thisRange]
/// where otherRange and thisDomain must match (asserted, both in size and in
/// their attached values), this relation becomes
///   [otherDomain] -> [thisRange].
///
/// The matching intermediate dims are converted to local ids in both systems.
/// Known dim values are then exchanged between the two. Finally the systems
/// are concatenated via `append`, and `removeRedundantLocalVars` is called.
/// NOTE(review): the domain/range counters are expected to be kept
/// up to date by the `removeIdRange` override, which `convertDimToLocal`
/// presumably dispatches to — confirm.
void FlatAffineRelation::compose(const FlatAffineRelation &other) {
  assert(getNumDomainDims() == other.getNumRangeDims() &&
         "Domain of this and range of other do not match");
  // The values attached to this domain must equal those attached to the
  // range of `other`, position by position.
  assert(std::equal(values.begin(), values.begin() + getNumDomainDims(),
                    other.values.begin() + other.getNumDomainDims()) &&
         "Domain of this and range of other do not match");

  // Work on a copy of `other`; `other` itself is left unchanged.
  FlatAffineRelation rel = other;

  // Convert `rel` from
  //    [otherDomain] -> [otherRange]
  // to
  //    [otherDomain] -> [otherRange thisRange]
  // and `this` from
  //    [thisDomain] -> [thisRange]
  // to
  //    [otherDomain thisDomain] -> [thisRange].
  // `removeDims` is the number of intermediate dims (otherRange ==
  // thisDomain) that are eliminated below.
  unsigned removeDims = rel.getNumRangeDims();
  insertDomainId(0, rel.getNumDomainDims());
  rel.appendRangeId(getNumRangeDims());

  // Merge symbol and local identifiers.
  mergeSymbolIds(rel);
  mergeLocalIds(rel);

  // Convert `rel` from [otherDomain] -> [otherRange thisRange] to
  // [otherDomain] -> [thisRange] by converting first otherRange range ids
  // to local ids.
  rel.convertDimToLocal(rel.getNumDomainDims(),
                        rel.getNumDomainDims() + removeDims);
  // Convert `this` from [otherDomain thisDomain] -> [thisRange] to
  // [otherDomain] -> [thisRange] by converting last thisDomain domain ids
  // to local ids.
  convertDimToLocal(getNumDomainDims() - removeDims, getNumDomainDims());

  // Snapshot the (optional) dim values of both systems before they are
  // modified below.
  auto thisMaybeValues = getMaybeDimValues();
  auto relMaybeValues = rel.getMaybeDimValues();

  // Add and match domain of `rel` to domain of `this`.
  for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
    if (relMaybeValues[i].hasValue())
      setValue(i, relMaybeValues[i].getValue());
  // Add and match range of `this` to range of `rel`.
  for (unsigned i = 0, e = getNumRangeDims(); i < e; ++i) {
    unsigned rangeIdx = rel.getNumDomainDims() + i;
    if (thisMaybeValues[rangeIdx].hasValue())
      rel.setValue(rangeIdx, thisMaybeValues[rangeIdx].getValue());
  }

  // Append `this` to `rel` and simplify constraints.
  rel.append(*this);
  rel.removeRedundantLocalVars();

  *this = rel;
}
1688 
1689 void FlatAffineRelation::inverse() {
1690   unsigned oldDomain = getNumDomainDims();
1691   unsigned oldRange = getNumRangeDims();
1692   // Add new range ids.
1693   appendRangeId(oldDomain);
1694   // Swap new ids with domain.
1695   for (unsigned i = 0; i < oldDomain; ++i)
1696     swapId(i, oldDomain + oldRange + i);
1697   // Remove the swapped domain.
1698   removeIdRange(0, oldDomain);
1699   // Set domain and range as inverse.
1700   numDomainDims = oldRange;
1701   numRangeDims = oldDomain;
1702 }
1703 
1704 void FlatAffineRelation::insertDomainId(unsigned pos, unsigned num) {
1705   assert(pos <= getNumDomainDims() &&
1706          "Id cannot be inserted at invalid position");
1707   insertDimId(pos, num);
1708   numDomainDims += num;
1709 }
1710 
1711 void FlatAffineRelation::insertRangeId(unsigned pos, unsigned num) {
1712   assert(pos <= getNumRangeDims() &&
1713          "Id cannot be inserted at invalid position");
1714   insertDimId(getNumDomainDims() + pos, num);
1715   numRangeDims += num;
1716 }
1717 
1718 void FlatAffineRelation::appendDomainId(unsigned num) {
1719   insertDimId(getNumDomainDims(), num);
1720   numDomainDims += num;
1721 }
1722 
1723 void FlatAffineRelation::appendRangeId(unsigned num) {
1724   insertDimId(getNumDimIds(), num);
1725   numRangeDims += num;
1726 }
1727 
1728 void FlatAffineRelation::removeIdRange(IdKind kind, unsigned idStart,
1729                                        unsigned idLimit) {
1730   assert(idLimit <= getNumIdKind(kind));
1731   if (idStart >= idLimit)
1732     return;
1733 
1734   FlatAffineValueConstraints::removeIdRange(kind, idStart, idLimit);
1735 
1736   // If kind is not SetDim, domain and range don't need to be updated.
1737   if (kind != IdKind::SetDim)
1738     return;
1739 
1740   // Compute number of domain and range identifiers to remove. This is done by
1741   // intersecting the range of domain/range ids with range of ids to remove.
1742   unsigned intersectDomainLHS = std::min(idLimit, getNumDomainDims());
1743   unsigned intersectDomainRHS = idStart;
1744   unsigned intersectRangeLHS = std::min(idLimit, getNumDimIds());
1745   unsigned intersectRangeRHS = std::max(idStart, getNumDomainDims());
1746 
1747   if (intersectDomainLHS > intersectDomainRHS)
1748     numDomainDims -= intersectDomainLHS - intersectDomainRHS;
1749   if (intersectRangeLHS > intersectRangeRHS)
1750     numRangeDims -= intersectRangeLHS - intersectRangeRHS;
1751 }
1752 
1753 LogicalResult mlir::getRelationFromMap(AffineMap &map,
1754                                        FlatAffineRelation &rel) {
1755   // Get flattened affine expressions.
1756   std::vector<SmallVector<int64_t, 8>> flatExprs;
1757   FlatAffineConstraints localVarCst;
1758   if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst)))
1759     return failure();
1760 
1761   unsigned oldDimNum = localVarCst.getNumDimIds();
1762   unsigned oldCols = localVarCst.getNumCols();
1763   unsigned numRangeIds = map.getNumResults();
1764   unsigned numDomainIds = map.getNumDims();
1765 
1766   // Add range as the new expressions.
1767   localVarCst.appendDimId(numRangeIds);
1768 
1769   // Add equalities between source and range.
1770   SmallVector<int64_t, 8> eq(localVarCst.getNumCols());
1771   for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1772     // Zero fill.
1773     std::fill(eq.begin(), eq.end(), 0);
1774     // Fill equality.
1775     for (unsigned j = 0, f = oldDimNum; j < f; ++j)
1776       eq[j] = flatExprs[i][j];
1777     for (unsigned j = oldDimNum, f = oldCols; j < f; ++j)
1778       eq[j + numRangeIds] = flatExprs[i][j];
1779     // Set this dimension to -1 to equate lhs and rhs and add equality.
1780     eq[numDomainIds + i] = -1;
1781     localVarCst.addEquality(eq);
1782   }
1783 
1784   // Create relation and return success.
1785   rel = FlatAffineRelation(numDomainIds, numRangeIds, localVarCst);
1786   return success();
1787 }
1788 
1789 LogicalResult mlir::getRelationFromMap(const AffineValueMap &map,
1790                                        FlatAffineRelation &rel) {
1791 
1792   AffineMap affineMap = map.getAffineMap();
1793   if (failed(getRelationFromMap(affineMap, rel)))
1794     return failure();
1795 
1796   // Set symbol values for domain dimensions and symbols.
1797   for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
1798     rel.setValue(i, map.getOperand(i));
1799   for (unsigned i = rel.getNumDimIds(), e = rel.getNumDimAndSymbolIds(); i < e;
1800        ++i)
1801     rel.setValue(i, map.getOperand(i - rel.getNumRangeDims()));
1802 
1803   return success();
1804 }
1805