1 //===- AffineStructures.cpp - MLIR Affine Structures Class-----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Structures for affine/polyhedral analysis of affine dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
14 #include "mlir/Analysis/Presburger/LinearTransform.h"
15 #include "mlir/Analysis/Presburger/Simplex.h"
16 #include "mlir/Analysis/Presburger/Utils.h"
17 #include "mlir/Dialect/Affine/IR/AffineOps.h"
18 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
19 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
20 #include "mlir/IR/AffineExprVisitor.h"
21 #include "mlir/IR/IntegerSet.h"
22 #include "mlir/Support/LLVM.h"
23 #include "mlir/Support/MathExtras.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/SmallPtrSet.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 #define DEBUG_TYPE "affine-structures"
31 
32 using namespace mlir;
33 using namespace presburger;
34 
35 namespace {
36 
37 // See comments for SimpleAffineExprFlattener.
38 // An AffineExprFlattener extends a SimpleAffineExprFlattener by recording
39 // constraint information associated with mod's, floordiv's, and ceildiv's
40 // in FlatAffineValueConstraints 'localVarCst'.
41 struct AffineExprFlattener : public SimpleAffineExprFlattener {
42 public:
43   // Constraints connecting newly introduced local variables (for mod's and
44   // div's) to existing (dimensional and symbolic) ones. These are always
45   // inequalities.
46   IntegerPolyhedron localVarCst;
47 
48   AffineExprFlattener(unsigned nDims, unsigned nSymbols)
49       : SimpleAffineExprFlattener(nDims, nSymbols),
50         localVarCst(PresburgerSpace::getSetSpace(nDims, nSymbols)) {}
51 
52 private:
53   // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
54   // The local identifier added is always a floordiv of a pure add/mul affine
55   // function of other identifiers, coefficients of which are specified in
56   // `dividend' and with respect to the positive constant `divisor'. localExpr
57   // is the simplified tree expression (AffineExpr) corresponding to the
58   // quantifier.
59   void addLocalFloorDivId(ArrayRef<int64_t> dividend, int64_t divisor,
60                           AffineExpr localExpr) override {
61     SimpleAffineExprFlattener::addLocalFloorDivId(dividend, divisor, localExpr);
62     // Update localVarCst.
63     localVarCst.addLocalFloorDiv(dividend, divisor);
64   }
65 };
66 
67 } // namespace
68 
69 // Flattens the expressions in map. Returns failure if 'expr' was unable to be
70 // flattened (i.e., semi-affine expressions not handled yet).
71 static LogicalResult
72 getFlattenedAffineExprs(ArrayRef<AffineExpr> exprs, unsigned numDims,
73                         unsigned numSymbols,
74                         std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
75                         FlatAffineValueConstraints *localVarCst) {
76   if (exprs.empty()) {
77     localVarCst->reset(numDims, numSymbols);
78     return success();
79   }
80 
81   AffineExprFlattener flattener(numDims, numSymbols);
82   // Use the same flattener to simplify each expression successively. This way
83   // local identifiers / expressions are shared.
84   for (auto expr : exprs) {
85     if (!expr.isPureAffine())
86       return failure();
87 
88     flattener.walkPostOrder(expr);
89   }
90 
91   assert(flattener.operandExprStack.size() == exprs.size());
92   flattenedExprs->clear();
93   flattenedExprs->assign(flattener.operandExprStack.begin(),
94                          flattener.operandExprStack.end());
95 
96   if (localVarCst)
97     localVarCst->clearAndCopyFrom(flattener.localVarCst);
98 
99   return success();
100 }
101 
102 // Flattens 'expr' into 'flattenedExpr'. Returns failure if 'expr' was unable to
103 // be flattened (semi-affine expressions not handled yet).
104 LogicalResult
105 mlir::getFlattenedAffineExpr(AffineExpr expr, unsigned numDims,
106                              unsigned numSymbols,
107                              SmallVectorImpl<int64_t> *flattenedExpr,
108                              FlatAffineValueConstraints *localVarCst) {
109   std::vector<SmallVector<int64_t, 8>> flattenedExprs;
110   LogicalResult ret = ::getFlattenedAffineExprs({expr}, numDims, numSymbols,
111                                                 &flattenedExprs, localVarCst);
112   *flattenedExpr = flattenedExprs[0];
113   return ret;
114 }
115 
116 /// Flattens the expressions in map. Returns failure if 'expr' was unable to be
117 /// flattened (i.e., semi-affine expressions not handled yet).
118 LogicalResult mlir::getFlattenedAffineExprs(
119     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
120     FlatAffineValueConstraints *localVarCst) {
121   if (map.getNumResults() == 0) {
122     localVarCst->reset(map.getNumDims(), map.getNumSymbols());
123     return success();
124   }
125   return ::getFlattenedAffineExprs(map.getResults(), map.getNumDims(),
126                                    map.getNumSymbols(), flattenedExprs,
127                                    localVarCst);
128 }
129 
130 LogicalResult mlir::getFlattenedAffineExprs(
131     IntegerSet set, std::vector<SmallVector<int64_t, 8>> *flattenedExprs,
132     FlatAffineValueConstraints *localVarCst) {
133   if (set.getNumConstraints() == 0) {
134     localVarCst->reset(set.getNumDims(), set.getNumSymbols());
135     return success();
136   }
137   return ::getFlattenedAffineExprs(set.getConstraints(), set.getNumDims(),
138                                    set.getNumSymbols(), flattenedExprs,
139                                    localVarCst);
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // FlatAffineConstraints / FlatAffineValueConstraints.
144 //===----------------------------------------------------------------------===//
145 
146 std::unique_ptr<FlatAffineValueConstraints>
147 FlatAffineValueConstraints::clone() const {
148   return std::make_unique<FlatAffineValueConstraints>(*this);
149 }
150 
151 // Construct from an IntegerSet.
152 FlatAffineValueConstraints::FlatAffineValueConstraints(IntegerSet set)
153     : IntegerPolyhedron(set.getNumInequalities(), set.getNumEqualities(),
154                         set.getNumDims() + set.getNumSymbols() + 1,
155                         PresburgerSpace::getSetSpace(set.getNumDims(),
156                                                      set.getNumSymbols(),
157                                                      /*numLocals=*/0)) {
158 
159   // Resize values.
160   values.resize(getNumIds(), None);
161 
162   // Flatten expressions and add them to the constraint system.
163   std::vector<SmallVector<int64_t, 8>> flatExprs;
164   FlatAffineValueConstraints localVarCst;
165   if (failed(getFlattenedAffineExprs(set, &flatExprs, &localVarCst))) {
166     assert(false && "flattening unimplemented for semi-affine integer sets");
167     return;
168   }
169   assert(flatExprs.size() == set.getNumConstraints());
170   insertId(IdKind::Local, getNumIdKind(IdKind::Local),
171            /*num=*/localVarCst.getNumLocalIds());
172 
173   for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
174     const auto &flatExpr = flatExprs[i];
175     assert(flatExpr.size() == getNumCols());
176     if (set.getEqFlags()[i]) {
177       addEquality(flatExpr);
178     } else {
179       addInequality(flatExpr);
180     }
181   }
182   // Add the other constraints involving local id's from flattening.
183   append(localVarCst);
184 }
185 
186 // Construct a hyperrectangular constraint set from ValueRanges that represent
187 // induction variables, lower and upper bounds. `ivs`, `lbs` and `ubs` are
188 // expected to match one to one. The order of variables and constraints is:
189 //
190 // ivs | lbs | ubs | eq/ineq
191 // ----+-----+-----+---------
192 //   1   -1     0      >= 0
193 // ----+-----+-----+---------
194 //  -1    0     1      >= 0
195 //
196 // All dimensions as set as DimId.
197 FlatAffineValueConstraints
198 FlatAffineValueConstraints::getHyperrectangular(ValueRange ivs, ValueRange lbs,
199                                                 ValueRange ubs) {
200   FlatAffineValueConstraints res;
201   unsigned nIvs = ivs.size();
202   assert(nIvs == lbs.size() && "expected as many lower bounds as ivs");
203   assert(nIvs == ubs.size() && "expected as many upper bounds as ivs");
204 
205   if (nIvs == 0)
206     return res;
207 
208   res.appendDimId(ivs);
209   unsigned lbsStart = res.appendDimId(lbs);
210   unsigned ubsStart = res.appendDimId(ubs);
211 
212   MLIRContext *ctx = ivs.front().getContext();
213   for (int ivIdx = 0, e = nIvs; ivIdx < e; ++ivIdx) {
214     // iv - lb >= 0
215     AffineMap lb = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
216                                   getAffineDimExpr(lbsStart + ivIdx, ctx));
217     if (failed(res.addBound(BoundType::LB, ivIdx, lb)))
218       llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
219     // -iv + ub >= 0
220     AffineMap ub = AffineMap::get(/*dimCount=*/3 * nIvs, /*symbolCount=*/0,
221                                   getAffineDimExpr(ubsStart + ivIdx, ctx));
222     if (failed(res.addBound(BoundType::UB, ivIdx, ub)))
223       llvm_unreachable("Unexpected FlatAffineValueConstraints creation error");
224   }
225   return res;
226 }
227 
228 void FlatAffineValueConstraints::reset(unsigned numReservedInequalities,
229                                        unsigned numReservedEqualities,
230                                        unsigned newNumReservedCols,
231                                        unsigned newNumDims,
232                                        unsigned newNumSymbols,
233                                        unsigned newNumLocals) {
234   assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
235          "minimum 1 column");
236   *this = FlatAffineValueConstraints(numReservedInequalities,
237                                      numReservedEqualities, newNumReservedCols,
238                                      newNumDims, newNumSymbols, newNumLocals);
239 }
240 
241 void FlatAffineValueConstraints::reset(unsigned newNumDims,
242                                        unsigned newNumSymbols,
243                                        unsigned newNumLocals) {
244   reset(/*numReservedInequalities=*/0, /*numReservedEqualities=*/0,
245         /*numReservedCols=*/newNumDims + newNumSymbols + newNumLocals + 1,
246         newNumDims, newNumSymbols, newNumLocals);
247 }
248 
249 void FlatAffineValueConstraints::reset(
250     unsigned numReservedInequalities, unsigned numReservedEqualities,
251     unsigned newNumReservedCols, unsigned newNumDims, unsigned newNumSymbols,
252     unsigned newNumLocals, ArrayRef<Value> valArgs) {
253   assert(newNumReservedCols >= newNumDims + newNumSymbols + newNumLocals + 1 &&
254          "minimum 1 column");
255   SmallVector<Optional<Value>, 8> newVals;
256   if (!valArgs.empty())
257     newVals.assign(valArgs.begin(), valArgs.end());
258 
259   *this = FlatAffineValueConstraints(
260       numReservedInequalities, numReservedEqualities, newNumReservedCols,
261       newNumDims, newNumSymbols, newNumLocals, newVals);
262 }
263 
264 void FlatAffineValueConstraints::reset(unsigned newNumDims,
265                                        unsigned newNumSymbols,
266                                        unsigned newNumLocals,
267                                        ArrayRef<Value> valArgs) {
268   reset(0, 0, newNumDims + newNumSymbols + newNumLocals + 1, newNumDims,
269         newNumSymbols, newNumLocals, valArgs);
270 }
271 
272 unsigned FlatAffineValueConstraints::appendDimId(ValueRange vals) {
273   unsigned pos = getNumDimIds();
274   insertId(IdKind::SetDim, pos, vals);
275   return pos;
276 }
277 
278 unsigned FlatAffineValueConstraints::appendSymbolId(ValueRange vals) {
279   unsigned pos = getNumSymbolIds();
280   insertId(IdKind::Symbol, pos, vals);
281   return pos;
282 }
283 
284 unsigned FlatAffineValueConstraints::insertDimId(unsigned pos,
285                                                  ValueRange vals) {
286   return insertId(IdKind::SetDim, pos, vals);
287 }
288 
289 unsigned FlatAffineValueConstraints::insertSymbolId(unsigned pos,
290                                                     ValueRange vals) {
291   return insertId(IdKind::Symbol, pos, vals);
292 }
293 
294 unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos,
295                                               unsigned num) {
296   unsigned absolutePos = IntegerPolyhedron::insertId(kind, pos, num);
297   values.insert(values.begin() + absolutePos, num, None);
298   assert(values.size() == getNumIds());
299   return absolutePos;
300 }
301 
302 unsigned FlatAffineValueConstraints::insertId(IdKind kind, unsigned pos,
303                                               ValueRange vals) {
304   assert(!vals.empty() && "expected ValueRange with Values");
305   unsigned num = vals.size();
306   unsigned absolutePos = IntegerPolyhedron::insertId(kind, pos, num);
307 
308   // If a Value is provided, insert it; otherwise use None.
309   for (unsigned i = 0; i < num; ++i)
310     values.insert(values.begin() + absolutePos + i,
311                   vals[i] ? Optional<Value>(vals[i]) : None);
312 
313   assert(values.size() == getNumIds());
314   return absolutePos;
315 }
316 
317 bool FlatAffineValueConstraints::hasValues() const {
318   return llvm::find_if(values, [](Optional<Value> id) {
319            return id.hasValue();
320          }) != values.end();
321 }
322 
323 /// Checks if two constraint systems are in the same space, i.e., if they are
324 /// associated with the same set of identifiers, appearing in the same order.
325 static bool areIdsAligned(const FlatAffineValueConstraints &a,
326                           const FlatAffineValueConstraints &b) {
327   return a.getNumDimIds() == b.getNumDimIds() &&
328          a.getNumSymbolIds() == b.getNumSymbolIds() &&
329          a.getNumIds() == b.getNumIds() &&
330          a.getMaybeValues().equals(b.getMaybeValues());
331 }
332 
333 /// Calls areIdsAligned to check if two constraint systems have the same set
334 /// of identifiers in the same order.
335 bool FlatAffineValueConstraints::areIdsAlignedWithOther(
336     const FlatAffineValueConstraints &other) {
337   return areIdsAligned(*this, other);
338 }
339 
340 /// Checks if the SSA values associated with `cst`'s identifiers in range
341 /// [start, end) are unique.
342 static bool LLVM_ATTRIBUTE_UNUSED areIdsUnique(
343     const FlatAffineValueConstraints &cst, unsigned start, unsigned end) {
344 
345   assert(start <= cst.getNumIds() && "Start position out of bounds");
346   assert(end <= cst.getNumIds() && "End position out of bounds");
347 
348   if (start >= end)
349     return true;
350 
351   SmallPtrSet<Value, 8> uniqueIds;
352   ArrayRef<Optional<Value>> maybeValues = cst.getMaybeValues();
353   for (Optional<Value> val : maybeValues) {
354     if (val.hasValue() && !uniqueIds.insert(val.getValue()).second)
355       return false;
356   }
357   return true;
358 }
359 
360 /// Checks if the SSA values associated with `cst`'s identifiers are unique.
361 static bool LLVM_ATTRIBUTE_UNUSED
362 areIdsUnique(const FlatAffineValueConstraints &cst) {
363   return areIdsUnique(cst, 0, cst.getNumIds());
364 }
365 
366 /// Checks if the SSA values associated with `cst`'s identifiers of kind `kind`
367 /// are unique.
368 static bool LLVM_ATTRIBUTE_UNUSED
369 areIdsUnique(const FlatAffineValueConstraints &cst, IdKind kind) {
370 
371   if (kind == IdKind::SetDim)
372     return areIdsUnique(cst, 0, cst.getNumDimIds());
373   if (kind == IdKind::Symbol)
374     return areIdsUnique(cst, cst.getNumDimIds(), cst.getNumDimAndSymbolIds());
375   if (kind == IdKind::Local)
376     return areIdsUnique(cst, cst.getNumDimAndSymbolIds(), cst.getNumIds());
377   llvm_unreachable("Unexpected IdKind");
378 }
379 
/// Merge and align the identifiers of A and B starting at 'offset', so that
/// both constraint systems get the union of the contained identifiers that is
/// dimension-wise and symbol-wise unique; both constraint systems are updated
/// so that they have the union of all identifiers, with A's original
/// identifiers appearing first followed by any of B's identifiers that didn't
/// appear in A. Local identifiers in B that have the same division
/// representation as local identifiers in A are merged into one.
///
/// E.g.: Input: A has ((%i, %j) [%M, %N]) and B has ((%k, %j) [%P, %N, %M])
///       Output: both A, B have ((%i, %j, %k) [%M, %N, %P])
///
/// The dim merge only touches positions >= `offset`. Preconditions (checked
/// by asserts):
///  - `offset` is at most the dim count of each system;
///  - the Values of each system are unique;
///  - every dim and symbol of A and B from `offset` onwards has a Value.
static void mergeAndAlignIds(unsigned offset, FlatAffineValueConstraints *a,
                             FlatAffineValueConstraints *b) {
  assert(offset <= a->getNumDimIds() && offset <= b->getNumDimIds());
  // A merge/align isn't meaningful if a cst's ids aren't distinct.
  assert(areIdsUnique(*a) && "A's values aren't unique");
  assert(areIdsUnique(*b) && "B's values aren't unique");

  // Dims/symbols from `offset` on are matched by Value, so each must have one.
  assert(std::all_of(a->getMaybeValues().begin() + offset,
                     a->getMaybeValues().begin() + a->getNumDimAndSymbolIds(),
                     [](Optional<Value> id) { return id.hasValue(); }));

  assert(std::all_of(b->getMaybeValues().begin() + offset,
                     b->getMaybeValues().begin() + b->getNumDimAndSymbolIds(),
                     [](Optional<Value> id) { return id.hasValue(); }));

  // Values of A's dims in [offset, numDims), in order.
  SmallVector<Value, 4> aDimValues;
  a->getValues(offset, a->getNumDimIds(), &aDimValues);

  {
    // Merge dims from A into B. Walking A's dims in order, make B's dim at
    // position `d` hold the same Value: swap it in if B already has it,
    // otherwise insert a fresh dim there.
    unsigned d = offset;
    for (auto aDimValue : aDimValues) {
      unsigned loc;
      if (b->findId(aDimValue, &loc)) {
        assert(loc >= offset && "A's dim appears in B's aligned range");
        assert(loc < b->getNumDimIds() &&
               "A's dim appears in B's non-dim position");
        b->swapId(d, loc);
      } else {
        b->insertDimId(d, aDimValue);
      }
      d++;
    }
    // Dimensions that are in B, but not in A, are added at the end.
    for (unsigned t = a->getNumDimIds(), e = b->getNumDimIds(); t < e; t++) {
      a->appendDimId(b->getValue(t));
    }
    assert(a->getNumDimIds() == b->getNumDimIds() &&
           "expected same number of dims");
  }

  // Merge and align symbols of A and B
  a->mergeSymbolIds(*b);
  // Merge and align local ids of A and B
  a->mergeLocalIds(*b);

  assert(areIdsAligned(*a, *b) && "IDs expected to be aligned");
}
437 
438 // Call 'mergeAndAlignIds' to align constraint systems of 'this' and 'other'.
439 void FlatAffineValueConstraints::mergeAndAlignIdsWithOther(
440     unsigned offset, FlatAffineValueConstraints *other) {
441   mergeAndAlignIds(offset, this, other);
442 }
443 
444 LogicalResult
445 FlatAffineValueConstraints::composeMap(const AffineValueMap *vMap) {
446   return composeMatchingMap(
447       computeAlignedMap(vMap->getAffineMap(), vMap->getOperands()));
448 }
449 
// Similar to `composeMap` except that no Values need be associated with the
// constraint system nor are they looked at -- the dimensions and symbols of
// `other` are expected to correspond 1:1 to `this` system.
//
// For each result r of `other`, a new dimension is inserted at dim position r,
// together with an equality setting that dimension to the flattened result
// expression. Returns failure if flattenAlignedMapAndMergeLocals fails.
LogicalResult FlatAffineValueConstraints::composeMatchingMap(AffineMap other) {
  assert(other.getNumDims() == getNumDimIds() && "dim mismatch");
  assert(other.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");

  // Flatten the map's results. Judging by the helper's name, any locals the
  // flattening needs are merged into this system -- NOTE(review): confirm.
  std::vector<SmallVector<int64_t, 8>> flatExprs;
  if (failed(flattenAlignedMapAndMergeLocals(other, &flatExprs)))
    return failure();
  assert(flatExprs.size() == other.getNumResults());

  // Add dimensions corresponding to the map's results.
  insertDimId(/*pos=*/0, /*num=*/other.getNumResults());

  // We add one equality for each result connecting the result dim of the map to
  // the other identifiers.
  // E.g.: if the expression is 16*i0 + i1, and this is the r^th
  // iteration/result of the value map, we are adding the equality:
  // d_r - 16*i0 - i1 = 0. Similarly, when flattening (i0 + 1, i0 + 8*i2), we
  // add two equalities: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
  for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
    const auto &flatExpr = flatExprs[r];
    // Row layout: [map inputs (dims, symbols) | locals | constant].
    assert(flatExpr.size() >= other.getNumInputs() + 1);

    SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
    // Set the coefficient for this result to one.
    eqToAdd[r] = 1;

    // Dims and symbols. The `e` result dims inserted above come first, so the
    // map's i-th input (dim or symbol) sits at column e + i in this system.
    for (unsigned i = 0, f = other.getNumInputs(); i < f; i++) {
      // Negated, because the equality has the form d_r - expr = 0.
      eqToAdd[e + i] = -flatExpr[i];
    }
    // The expression's local columns map onto this system's leading local
    // columns, which start right after all dims and symbols.
    unsigned j = getNumDimIds() + getNumSymbolIds();
    unsigned end = flatExpr.size() - 1;
    for (unsigned i = other.getNumInputs(); i < end; i++, j++) {
      eqToAdd[j] = -flatExpr[i];
    }

    // Constant term.
    eqToAdd[getNumCols() - 1] = -flatExpr[flatExpr.size() - 1];

    // Add the equality connecting the result of the map to this constraint set.
    addEquality(eqToAdd);
  }

  return success();
}
500 
501 // Turn a symbol into a dimension.
502 static void turnSymbolIntoDim(FlatAffineValueConstraints *cst, Value id) {
503   unsigned pos;
504   if (cst->findId(id, &pos) && pos >= cst->getNumDimIds() &&
505       pos < cst->getNumDimAndSymbolIds()) {
506     cst->swapId(pos, cst->getNumDimIds());
507     cst->setDimSymbolSeparation(cst->getNumSymbolIds() - 1);
508   }
509 }
510 
/// Merge and align symbols of `this` and `other` such that both get the union
/// of their symbols, each appearing once. Symbols in `this` and `other` should
/// be unique. Symbols with Value as `None` are considered to be unequal to all
/// other symbols.
///
/// Afterwards, both systems list the original symbols of `this` first (in
/// their original order), followed by the symbols that only `other` had.
void FlatAffineValueConstraints::mergeSymbolIds(
    FlatAffineValueConstraints &other) {

  assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique");
  assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique");

  // Values of this system's symbols, in order.
  SmallVector<Value, 4> aSymValues;
  getValues(getNumDimIds(), getNumDimAndSymbolIds(), &aSymValues);

  // Merge symbols: merge symbols into `other` first from `this`.
  // `s` is the absolute position in `other` that receives the next symbol of
  // `this`.
  unsigned s = other.getNumDimIds();
  for (Value aSymValue : aSymValues) {
    unsigned loc;
    // If the id is a symbol in `other`, then align it, otherwise assume that
    // it is a new symbol
    if (other.findId(aSymValue, &loc) && loc >= other.getNumDimIds() &&
        loc < other.getNumDimAndSymbolIds())
      other.swapId(s, loc);
    else
      other.insertSymbolId(s - other.getNumDimIds(), aSymValue);
    s++;
  }

  // Symbols that are in other, but not in this, are added at the end.
  // At this point the leading symbols of `other` match ours, so its extra
  // symbols start at other.getNumDimIds() + getNumSymbolIds(). Both loop bounds
  // are evaluated once, before the inserts below grow our symbol count.
  for (unsigned t = other.getNumDimIds() + getNumSymbolIds(),
                e = other.getNumDimAndSymbolIds();
       t < e; t++)
    insertSymbolId(getNumSymbolIds(), other.getValue(t));

  assert(getNumSymbolIds() == other.getNumSymbolIds() &&
         "expected same number of symbols");
  assert(areIdsUnique(*this, IdKind::Symbol) && "Symbol ids are not unique");
  assert(areIdsUnique(other, IdKind::Symbol) && "Symbol ids are not unique");
}
549 
550 // Changes all symbol identifiers which are loop IVs to dim identifiers.
551 void FlatAffineValueConstraints::convertLoopIVSymbolsToDims() {
552   // Gather all symbols which are loop IVs.
553   SmallVector<Value, 4> loopIVs;
554   for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++) {
555     if (hasValue(i) && getForInductionVarOwner(getValue(i)))
556       loopIVs.push_back(getValue(i));
557   }
558   // Turn each symbol in 'loopIVs' into a dim identifier.
559   for (auto iv : loopIVs) {
560     turnSymbolIntoDim(this, iv);
561   }
562 }
563 
564 void FlatAffineValueConstraints::addInductionVarOrTerminalSymbol(Value val) {
565   if (containsId(val))
566     return;
567 
568   // Caller is expected to fully compose map/operands if necessary.
569   assert((isTopLevelValue(val) || isForInductionVar(val)) &&
570          "non-terminal symbol / loop IV expected");
571   // Outer loop IVs could be used in forOp's bounds.
572   if (auto loop = getForInductionVarOwner(val)) {
573     appendDimId(val);
574     if (failed(this->addAffineForOpDomain(loop)))
575       LLVM_DEBUG(
576           loop.emitWarning("failed to add domain info to constraint system"));
577     return;
578   }
579   // Add top level symbol.
580   appendSymbolId(val);
581   // Check if the symbol is a constant.
582   if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>())
583     addBound(BoundType::EQ, val, constOp.value());
584 }
585 
/// Adds the constraints implied by `forOp`'s bounds (and, when possible, its
/// step) on its induction variable.
/// - The IV must already be an identifier of this system. Otherwise this
///   asserts, and returns failure when asserts are disabled.
/// - Constant bounds are added directly. The constant upper bound is
///   exclusive, hence the `- 1`. Otherwise the bound maps are added with the
///   loop's bound operands.
/// - A non-unit step is encoded via a new local identifier only when the lower
///   bound is constant. Otherwise the step is ignored (over-approximation).
/// Returns failure if adding a non-constant bound fails.
LogicalResult
FlatAffineValueConstraints::addAffineForOpDomain(AffineForOp forOp) {
  unsigned pos;
  // Pre-condition for this method.
  if (!findId(forOp.getInductionVar(), &pos)) {
    assert(false && "Value not found");
    return failure();
  }

  int64_t step = forOp.getStep();
  if (step != 1) {
    if (!forOp.hasConstantLowerBound())
      LLVM_DEBUG(forOp.emitWarning("domain conservatively approximated"));
    else {
      // Add constraints for the stride.
      // (iv - lb) % step = 0 can be written as:
      // (iv - lb) - step * q = 0 where q = (iv - lb) / step.
      // Add local variable 'q' and add the above equality.
      // The first constraint is q = (iv - lb) floordiv step
      SmallVector<int64_t, 8> dividend(getNumCols(), 0);
      int64_t lb = forOp.getConstantLowerBound();
      dividend[pos] = 1;
      dividend.back() -= lb;
      addLocalFloorDiv(dividend, step);
      // Second constraint: (iv - lb) - step * q = 0.
      // Sized after addLocalFloorDiv, so the row includes q's new column.
      SmallVector<int64_t, 8> eq(getNumCols(), 0);
      eq[pos] = 1;
      eq.back() -= lb;
      // For the local var just added above.
      // NOTE(review): relies on q being the last identifier, i.e. the column
      // just before the constant term.
      eq[getNumCols() - 2] = -step;
      addEquality(eq);
    }
  }

  if (forOp.hasConstantLowerBound()) {
    addBound(BoundType::LB, pos, forOp.getConstantLowerBound());
  } else {
    // Non-constant lower bound case.
    if (failed(addBound(BoundType::LB, pos, forOp.getLowerBoundMap(),
                        forOp.getLowerBoundOperands())))
      return failure();
  }

  if (forOp.hasConstantUpperBound()) {
    // The loop's upper bound is exclusive; the added bound is inclusive.
    addBound(BoundType::UB, pos, forOp.getConstantUpperBound() - 1);
    return success();
  }
  // Non-constant upper bound case.
  return addBound(BoundType::UB, pos, forOp.getUpperBoundMap(),
                  forOp.getUpperBoundOperands());
}
637 
/// Adds bounds from slice bound maps to the first `lbMaps.size()` dimensions.
/// - `lbMaps[i]` / `ubMaps[i]` bound dimension `i`; a null map adds no bound
///   on that side.
/// - Every non-null map takes `operands` as its inputs.
/// - When a slice is a single iteration `[d, d + 1)`, where `d` is an operand
///   owned by an affine.for loop, the domain of that loop is added through
///   addAffineForOpDomain instead.
/// Returns failure if such an equality slice is not of that supported form, or
/// if adding any bound fails.
LogicalResult
FlatAffineValueConstraints::addDomainFromSliceMaps(ArrayRef<AffineMap> lbMaps,
                                                   ArrayRef<AffineMap> ubMaps,
                                                   ArrayRef<Value> operands) {
  assert(lbMaps.size() == ubMaps.size());
  assert(lbMaps.size() <= getNumDimIds());

  for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
    AffineMap lbMap = lbMaps[i];
    AffineMap ubMap = ubMaps[i];
    assert(!lbMap || lbMap.getNumInputs() == operands.size());
    assert(!ubMap || ubMap.getNumInputs() == operands.size());

    // Check if this slice is just an equality along this dimension. If so,
    // retrieve the existing loop it equates to and add it to the system.
    if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
        ubMap.getNumResults() == 1 &&
        lbMap.getResult(0) + 1 == ubMap.getResult(0) &&
        // The condition above will be true for maps describing a single
        // iteration (e.g., lbMap.getResult(0) = 0, ubMap.getResult(0) = 1).
        // Make sure we skip those cases by checking that the lb result is not
        // just a constant.
        !lbMap.getResult(0).isa<AffineConstantExpr>()) {
      // Limited support: we expect the lb result to be just a loop dimension.
      // Not supported otherwise for now.
      AffineDimExpr result = lbMap.getResult(0).dyn_cast<AffineDimExpr>();
      if (!result)
        return failure();

      // The dim expression indexes into `operands`; it must be a loop IV.
      AffineForOp loop =
          getForInductionVarOwner(operands[result.getPosition()]);
      if (!loop)
        return failure();

      if (failed(addAffineForOpDomain(loop)))
        return failure();
      continue;
    }

    // This slice refers to a loop that doesn't exist in the IR yet. Add its
    // bounds to the system assuming its dimension identifier position is the
    // same as the position of the loop in the loop nest.
    if (lbMap && failed(addBound(BoundType::LB, i, lbMap, operands)))
      return failure();
    if (ubMap && failed(addBound(BoundType::UB, i, ubMap, operands)))
      return failure();
  }
  return success();
}
687 
688 void FlatAffineValueConstraints::addAffineIfOpDomain(AffineIfOp ifOp) {
689   // Create the base constraints from the integer set attached to ifOp.
690   FlatAffineValueConstraints cst(ifOp.getIntegerSet());
691 
692   // Bind ids in the constraints to ifOp operands.
693   SmallVector<Value, 4> operands = ifOp.getOperands();
694   cst.setValues(0, cst.getNumDimAndSymbolIds(), operands);
695 
696   // Merge the constraints from ifOp to the current domain. We need first merge
697   // and align the IDs from both constraints, and then append the constraints
698   // from the ifOp into the current one.
699   mergeAndAlignIdsWithOther(0, &cst);
700   append(cst);
701 }
702 
703 bool FlatAffineValueConstraints::hasConsistentState() const {
704   return IntegerPolyhedron::hasConsistentState() &&
705          values.size() == getNumIds();
706 }
707 
708 void FlatAffineValueConstraints::removeIdRange(IdKind kind, unsigned idStart,
709                                                unsigned idLimit) {
710   IntegerPolyhedron::removeIdRange(kind, idStart, idLimit);
711   unsigned offset = getIdKindOffset(kind);
712   values.erase(values.begin() + idStart + offset,
713                values.begin() + idLimit + offset);
714 }
715 
716 // Determine whether the identifier at 'pos' (say id_r) can be expressed as
717 // modulo of another known identifier (say id_n) w.r.t a constant. For example,
718 // if the following constraints hold true:
719 // ```
720 // 0 <= id_r <= divisor - 1
721 // id_n - (divisor * q_expr) = id_r
722 // ```
723 // where `id_n` is a known identifier (called dividend), and `q_expr` is an
724 // `AffineExpr` (called the quotient expression), `id_r` can be written as:
725 //
726 // `id_r = id_n mod divisor`.
727 //
// Additionally, in a special case of the above constraints where `q_expr` is an
729 // identifier itself that is not yet known (say `id_q`), it can be written as a
730 // floordiv in the following way:
731 //
732 // `id_q = id_n floordiv divisor`.
733 //
734 // Returns true if the above mod or floordiv are detected, updating 'memo' with
735 // these new expressions. Returns false otherwise.
736 static bool detectAsMod(const FlatAffineValueConstraints &cst, unsigned pos,
737                         int64_t lbConst, int64_t ubConst,
738                         SmallVectorImpl<AffineExpr> &memo,
739                         MLIRContext *context) {
740   assert(pos < cst.getNumIds() && "invalid position");
741 
742   // Check if a divisor satisfying the condition `0 <= id_r <= divisor - 1` can
743   // be determined.
744   if (lbConst != 0 || ubConst < 1)
745     return false;
746   int64_t divisor = ubConst + 1;
747 
748   // Check for the aforementioned conditions in each equality.
749   for (unsigned curEquality = 0, numEqualities = cst.getNumEqualities();
750        curEquality < numEqualities; curEquality++) {
751     int64_t coefficientAtPos = cst.atEq(curEquality, pos);
752     // If current equality does not involve `id_r`, continue to the next
753     // equality.
754     if (coefficientAtPos == 0)
755       continue;
756 
757     // Constant term should be 0 in this equality.
758     if (cst.atEq(curEquality, cst.getNumCols() - 1) != 0)
759       continue;
760 
761     // Traverse through the equality and construct the dividend expression
762     // `dividendExpr`, to contain all the identifiers which are known and are
763     // not divisible by `(coefficientAtPos * divisor)`. Hope here is that the
764     // `dividendExpr` gets simplified into a single identifier `id_n` discussed
765     // above.
766     auto dividendExpr = getAffineConstantExpr(0, context);
767 
768     // Track the terms that go into quotient expression, later used to detect
769     // additional floordiv.
770     unsigned quotientCount = 0;
771     int quotientPosition = -1;
772     int quotientSign = 1;
773 
774     // Consider each term in the current equality.
775     unsigned curId, e;
776     for (curId = 0, e = cst.getNumDimAndSymbolIds(); curId < e; ++curId) {
777       // Ignore id_r.
778       if (curId == pos)
779         continue;
780       int64_t coefficientOfCurId = cst.atEq(curEquality, curId);
781       // Ignore ids that do not contribute to the current equality.
782       if (coefficientOfCurId == 0)
783         continue;
784       // Check if the current id goes into the quotient expression.
785       if (coefficientOfCurId % (divisor * coefficientAtPos) == 0) {
786         quotientCount++;
787         quotientPosition = curId;
788         quotientSign = (coefficientOfCurId * coefficientAtPos) > 0 ? 1 : -1;
789         continue;
790       }
791       // Identifiers that are part of dividendExpr should be known.
792       if (!memo[curId])
793         break;
794       // Append the current identifier to the dividend expression.
795       dividendExpr = dividendExpr + memo[curId] * coefficientOfCurId;
796     }
797 
798     // Can't construct expression as it depends on a yet uncomputed id.
799     if (curId < e)
800       continue;
801 
802     // Express `id_r` in terms of the other ids collected so far.
803     if (coefficientAtPos > 0)
804       dividendExpr = (-dividendExpr).floorDiv(coefficientAtPos);
805     else
806       dividendExpr = dividendExpr.floorDiv(-coefficientAtPos);
807 
808     // Simplify the expression.
809     dividendExpr = simplifyAffineExpr(dividendExpr, cst.getNumDimIds(),
810                                       cst.getNumSymbolIds());
811     // Only if the final dividend expression is just a single id (which we call
812     // `id_n`), we can proceed.
813     // TODO: Handle AffineSymbolExpr as well. There is no reason to restrict it
814     // to dims themselves.
815     auto dimExpr = dividendExpr.dyn_cast<AffineDimExpr>();
816     if (!dimExpr)
817       continue;
818 
819     // Express `id_r` as `id_n % divisor` and store the expression in `memo`.
820     if (quotientCount >= 1) {
821       auto ub = cst.getConstantBound(FlatAffineValueConstraints::BoundType::UB,
822                                      dimExpr.getPosition());
823       // If `id_n` has an upperbound that is less than the divisor, mod can be
824       // eliminated altogether.
825       if (ub.hasValue() && ub.getValue() < divisor)
826         memo[pos] = dimExpr;
827       else
828         memo[pos] = dimExpr % divisor;
829       // If a unique quotient `id_q` was seen, it can be expressed as
830       // `id_n floordiv divisor`.
831       if (quotientCount == 1 && !memo[quotientPosition])
832         memo[quotientPosition] = dimExpr.floorDiv(divisor) * quotientSign;
833 
834       return true;
835     }
836   }
837   return false;
838 }
839 
840 /// Check if the pos^th identifier can be expressed as a floordiv of an affine
841 /// function of other identifiers (where the divisor is a positive constant)
842 /// given the initial set of expressions in `exprs`. If it can be, the
843 /// corresponding position in `exprs` is set as the detected affine expr. For
844 /// eg: 4q <= i + j <= 4q + 3   <=>   q = (i + j) floordiv 4. An equality can
845 /// also yield a floordiv: eg.  4q = i + j <=> q = (i + j) floordiv 4. 32q + 28
846 /// <= i <= 32q + 31 => q = i floordiv 32.
847 static bool detectAsFloorDiv(const FlatAffineValueConstraints &cst,
848                              unsigned pos, MLIRContext *context,
849                              SmallVectorImpl<AffineExpr> &exprs) {
850   assert(pos < cst.getNumIds() && "invalid position");
851 
852   // Get upper-lower bound pair for this variable.
853   SmallVector<bool, 8> foundRepr(cst.getNumIds(), false);
854   for (unsigned i = 0, e = cst.getNumIds(); i < e; ++i)
855     if (exprs[i])
856       foundRepr[i] = true;
857 
858   SmallVector<int64_t, 8> dividend;
859   unsigned divisor;
860   auto ulPair = computeSingleVarRepr(cst, foundRepr, pos, dividend, divisor);
861 
862   // No upper-lower bound pair found for this var.
863   if (ulPair.kind == ReprKind::None || ulPair.kind == ReprKind::Equality)
864     return false;
865 
866   // Construct the dividend expression.
867   auto dividendExpr = getAffineConstantExpr(dividend.back(), context);
868   for (unsigned c = 0, f = cst.getNumIds(); c < f; c++)
869     if (dividend[c] != 0)
870       dividendExpr = dividendExpr + dividend[c] * exprs[c];
871 
872   // Successfully detected the floordiv.
873   exprs[pos] = dividendExpr.floorDiv(divisor);
874   return true;
875 }
876 
877 std::pair<AffineMap, AffineMap>
878 FlatAffineValueConstraints::getLowerAndUpperBound(
879     unsigned pos, unsigned offset, unsigned num, unsigned symStartPos,
880     ArrayRef<AffineExpr> localExprs, MLIRContext *context) const {
881   assert(pos + offset < getNumDimIds() && "invalid dim start pos");
882   assert(symStartPos >= (pos + offset) && "invalid sym start pos");
883   assert(getNumLocalIds() == localExprs.size() &&
884          "incorrect local exprs count");
885 
886   SmallVector<unsigned, 4> lbIndices, ubIndices, eqIndices;
887   getLowerAndUpperBoundIndices(pos + offset, &lbIndices, &ubIndices, &eqIndices,
888                                offset, num);
889 
890   /// Add to 'b' from 'a' in set [0, offset) U [offset + num, symbStartPos).
891   auto addCoeffs = [&](ArrayRef<int64_t> a, SmallVectorImpl<int64_t> &b) {
892     b.clear();
893     for (unsigned i = 0, e = a.size(); i < e; ++i) {
894       if (i < offset || i >= offset + num)
895         b.push_back(a[i]);
896     }
897   };
898 
899   SmallVector<int64_t, 8> lb, ub;
900   SmallVector<AffineExpr, 4> lbExprs;
901   unsigned dimCount = symStartPos - num;
902   unsigned symCount = getNumDimAndSymbolIds() - symStartPos;
903   lbExprs.reserve(lbIndices.size() + eqIndices.size());
904   // Lower bound expressions.
905   for (auto idx : lbIndices) {
906     auto ineq = getInequality(idx);
907     // Extract the lower bound (in terms of other coeff's + const), i.e., if
908     // i - j + 1 >= 0 is the constraint, 'pos' is for i the lower bound is j
909     // - 1.
910     addCoeffs(ineq, lb);
911     std::transform(lb.begin(), lb.end(), lb.begin(), std::negate<int64_t>());
912     auto expr =
913         getAffineExprFromFlatForm(lb, dimCount, symCount, localExprs, context);
914     // expr ceildiv divisor is (expr + divisor - 1) floordiv divisor
915     int64_t divisor = std::abs(ineq[pos + offset]);
916     expr = (expr + divisor - 1).floorDiv(divisor);
917     lbExprs.push_back(expr);
918   }
919 
920   SmallVector<AffineExpr, 4> ubExprs;
921   ubExprs.reserve(ubIndices.size() + eqIndices.size());
922   // Upper bound expressions.
923   for (auto idx : ubIndices) {
924     auto ineq = getInequality(idx);
925     // Extract the upper bound (in terms of other coeff's + const).
926     addCoeffs(ineq, ub);
927     auto expr =
928         getAffineExprFromFlatForm(ub, dimCount, symCount, localExprs, context);
929     expr = expr.floorDiv(std::abs(ineq[pos + offset]));
930     // Upper bound is exclusive.
931     ubExprs.push_back(expr + 1);
932   }
933 
934   // Equalities. It's both a lower and a upper bound.
935   SmallVector<int64_t, 4> b;
936   for (auto idx : eqIndices) {
937     auto eq = getEquality(idx);
938     addCoeffs(eq, b);
939     if (eq[pos + offset] > 0)
940       std::transform(b.begin(), b.end(), b.begin(), std::negate<int64_t>());
941 
942     // Extract the upper bound (in terms of other coeff's + const).
943     auto expr =
944         getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
945     expr = expr.floorDiv(std::abs(eq[pos + offset]));
946     // Upper bound is exclusive.
947     ubExprs.push_back(expr + 1);
948     // Lower bound.
949     expr =
950         getAffineExprFromFlatForm(b, dimCount, symCount, localExprs, context);
951     expr = expr.ceilDiv(std::abs(eq[pos + offset]));
952     lbExprs.push_back(expr);
953   }
954 
955   auto lbMap = AffineMap::get(dimCount, symCount, lbExprs, context);
956   auto ubMap = AffineMap::get(dimCount, symCount, ubExprs, context);
957 
958   return {lbMap, ubMap};
959 }
960 
/// Computes the lower and upper bounds of the first 'num' dimensional
/// identifiers (starting at 'offset') as affine maps of the remaining
/// identifiers (dimensional and symbolic identifiers). Local identifiers are
/// themselves explicitly computed as affine functions of other identifiers in
/// this process if needed.
///
/// For each i in [0, num), (*lbMaps)[i] and (*ubMaps)[i] receive the lower
/// bound and the (exclusive) upper bound of identifier `offset + i`. Both
/// vectors are indexed, not appended to, so they must already hold at least
/// `num` entries. The maps have `getNumDimIds() - num` dims (the dims outside
/// [offset, offset + num), renumbered contiguously) and `getNumSymbolIds()`
/// symbols. When a bound cannot be determined, the map may end up null or
/// with zero results. If local identifiers are present, the map may instead
/// keep the value it held on entry; NOTE(review): confirm callers pass
/// null-initialized maps.
///
/// This method first normalizes the constraints of this system by their GCD,
/// so it modifies the system.
void FlatAffineValueConstraints::getSliceBounds(
    unsigned offset, unsigned num, MLIRContext *context,
    SmallVectorImpl<AffineMap> *lbMaps, SmallVectorImpl<AffineMap> *ubMaps) {
  // NOTE(review): only `num` is checked here, not `offset + num`.
  assert(num < getNumDimIds() && "invalid range");

  // Basic simplification.
  normalizeConstraintsByGCD();

  LLVM_DEBUG(llvm::dbgs() << "getSliceBounds for first " << num
                          << " identifiers\n");
  LLVM_DEBUG(dump());

  // Record computed/detected identifiers.
  SmallVector<AffineExpr, 8> memo(getNumIds());
  // Initialize dimensional and symbolic identifiers. The dims after the
  // sliced range are renumbered by -num so that the map dims are contiguous.
  // The sliced dims [offset, offset + num) and the locals start out unknown
  // (null).
  for (unsigned i = 0, e = getNumDimIds(); i < e; i++) {
    if (i < offset)
      memo[i] = getAffineDimExpr(i, context);
    else if (i >= offset + num)
      memo[i] = getAffineDimExpr(i - num, context);
  }
  for (unsigned i = getNumDimIds(), e = getNumDimAndSymbolIds(); i < e; i++)
    memo[i] = getAffineSymbolExpr(i - getNumDimIds(), context);

  bool changed;
  do {
    changed = false;
    // Identify yet unknown identifiers as constants or mod's / floordiv's of
    // other identifiers if possible.
    for (unsigned pos = 0; pos < getNumIds(); pos++) {
      if (memo[pos])
        continue;

      auto lbConst = getConstantBound(BoundType::LB, pos);
      auto ubConst = getConstantBound(BoundType::UB, pos);
      if (lbConst.hasValue() && ubConst.hasValue()) {
        // Detect equality to a constant.
        if (lbConst.getValue() == ubConst.getValue()) {
          memo[pos] = getAffineConstantExpr(lbConst.getValue(), context);
          changed = true;
          continue;
        }

        // Detect an identifier as modulo of another identifier w.r.t a
        // constant.
        if (detectAsMod(*this, pos, lbConst.getValue(), ubConst.getValue(),
                        memo, context)) {
          changed = true;
          continue;
        }
      }

      // Detect an identifier as a floordiv of an affine function of other
      // identifiers (divisor is a positive constant).
      if (detectAsFloorDiv(*this, pos, context, memo)) {
        changed = true;
        continue;
      }

      // Detect an identifier as an expression of other identifiers. Only the
      // equality reported by findConstraintWithNonZeroAt is tried.
      unsigned idx;
      if (!findConstraintWithNonZeroAt(pos, /*isEq=*/true, &idx)) {
        continue;
      }

      // Build AffineExpr solving for identifier 'pos' in terms of all others.
      auto expr = getAffineConstantExpr(0, context);
      unsigned j, e;
      for (j = 0, e = getNumIds(); j < e; ++j) {
        if (j == pos)
          continue;
        int64_t c = atEq(idx, j);
        if (c == 0)
          continue;
        // If any of the involved IDs hasn't been found yet, we can't proceed.
        if (!memo[j])
          break;
        expr = expr + memo[j] * c;
      }
      if (j < e)
        // Can't construct expression as it depends on a yet uncomputed
        // identifier.
        continue;

      // Add constant term to AffineExpr.
      expr = expr + atEq(idx, getNumIds());
      // Solve vPos * x + expr = 0 for x.
      int64_t vPos = atEq(idx, pos);
      assert(vPos != 0 && "expected non-zero here");
      if (vPos > 0)
        expr = (-expr).floorDiv(vPos);
      else
        // vPos < 0.
        expr = expr.floorDiv(-vPos);
      // Successfully constructed expression.
      memo[pos] = expr;
      changed = true;
    }
    // This loop is guaranteed to reach a fixed point - since once an
    // identifier's explicit form is computed (in memo[pos]), it's not updated
    // again.
  } while (changed);

  // Set the lower and upper bound maps for all the identifiers that were
  // computed as affine expressions of the rest as the "detected expr" and
  // "detected expr + 1" respectively; set the undetected ones to null.
  // `tmpClone` is created lazily, at most once, and shared by every position
  // that needs the general bound computation.
  Optional<FlatAffineValueConstraints> tmpClone;
  for (unsigned pos = 0; pos < num; pos++) {
    unsigned numMapDims = getNumDimIds() - num;
    unsigned numMapSymbols = getNumSymbolIds();
    AffineExpr expr = memo[pos + offset];
    if (expr)
      expr = simplifyAffineExpr(expr, numMapDims, numMapSymbols);

    AffineMap &lbMap = (*lbMaps)[pos];
    AffineMap &ubMap = (*ubMaps)[pos];

    if (expr) {
      // An exact expression pins the identifier: [expr, expr + 1).
      lbMap = AffineMap::get(numMapDims, numMapSymbols, expr);
      ubMap = AffineMap::get(numMapDims, numMapSymbols, expr + 1);
    } else {
      // TODO: Whenever there are local identifiers in the dependence
      // constraints, we'll conservatively over-approximate, since we don't
      // always explicitly compute them above (in the while loop).
      if (getNumLocalIds() == 0) {
        // Work on a copy so that we don't update this constraint system.
        if (!tmpClone) {
          tmpClone.emplace(FlatAffineValueConstraints(*this));
          // Removing redundant inequalities is necessary so that we don't get
          // redundant loop bounds.
          tmpClone->removeRedundantInequalities();
        }
        // symStartPos = getNumDimIds(): every remaining dim becomes a map dim.
        std::tie(lbMap, ubMap) = tmpClone->getLowerAndUpperBound(
            pos, offset, num, getNumDimIds(), /*localExprs=*/{}, context);
      }

      // If the above fails, we'll just use the constant lower bound and the
      // constant upper bound (if they exist) as the slice bounds.
      // TODO: being conservative for the moment in cases that
      // lead to multiple bounds - until getConstDifference in LoopFusion.cpp is
      // fixed (b/126426796).
      if (!lbMap || lbMap.getNumResults() > 1) {
        LLVM_DEBUG(llvm::dbgs()
                   << "WARNING: Potentially over-approximating slice lb\n");
        auto lbConst = getConstantBound(BoundType::LB, pos + offset);
        if (lbConst.hasValue()) {
          lbMap = AffineMap::get(
              numMapDims, numMapSymbols,
              getAffineConstantExpr(lbConst.getValue(), context));
        }
      }
      if (!ubMap || ubMap.getNumResults() > 1) {
        LLVM_DEBUG(llvm::dbgs()
                   << "WARNING: Potentially over-approximating slice ub\n");
        auto ubConst = getConstantBound(BoundType::UB, pos + offset);
        if (ubConst.hasValue()) {
          // The constant upper bound is inclusive; the map is exclusive.
          (ubMap) = AffineMap::get(
              numMapDims, numMapSymbols,
              getAffineConstantExpr(ubConst.getValue() + 1, context));
        }
      }
    }
    LLVM_DEBUG(llvm::dbgs()
               << "lb map for pos = " << Twine(pos + offset) << ", expr: ");
    LLVM_DEBUG(lbMap.dump(););
    LLVM_DEBUG(llvm::dbgs()
               << "ub map for pos = " << Twine(pos + offset) << ", expr: ");
    LLVM_DEBUG(ubMap.dump(););
  }
}
1135 
1136 LogicalResult FlatAffineValueConstraints::flattenAlignedMapAndMergeLocals(
1137     AffineMap map, std::vector<SmallVector<int64_t, 8>> *flattenedExprs) {
1138   FlatAffineValueConstraints localCst;
1139   if (failed(getFlattenedAffineExprs(map, flattenedExprs, &localCst))) {
1140     LLVM_DEBUG(llvm::dbgs()
1141                << "composition unimplemented for semi-affine maps\n");
1142     return failure();
1143   }
1144 
1145   // Add localCst information.
1146   if (localCst.getNumLocalIds() > 0) {
1147     unsigned numLocalIds = getNumLocalIds();
1148     // Insert local dims of localCst at the beginning.
1149     insertLocalId(/*pos=*/0, /*num=*/localCst.getNumLocalIds());
1150     // Insert local dims of `this` at the end of localCst.
1151     localCst.appendLocalId(/*num=*/numLocalIds);
1152     // Dimensions of localCst and this constraint set match. Append localCst to
1153     // this constraint set.
1154     append(localCst);
1155   }
1156 
1157   return success();
1158 }
1159 
1160 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
1161                                                    AffineMap boundMap) {
1162   assert(boundMap.getNumDims() == getNumDimIds() && "dim mismatch");
1163   assert(boundMap.getNumSymbols() == getNumSymbolIds() && "symbol mismatch");
1164   assert(pos < getNumDimAndSymbolIds() && "invalid position");
1165 
1166   // Equality follows the logic of lower bound except that we add an equality
1167   // instead of an inequality.
1168   assert((type != BoundType::EQ || boundMap.getNumResults() == 1) &&
1169          "single result expected");
1170   bool lower = type == BoundType::LB || type == BoundType::EQ;
1171 
1172   std::vector<SmallVector<int64_t, 8>> flatExprs;
1173   if (failed(flattenAlignedMapAndMergeLocals(boundMap, &flatExprs)))
1174     return failure();
1175   assert(flatExprs.size() == boundMap.getNumResults());
1176 
1177   // Add one (in)equality for each result.
1178   for (const auto &flatExpr : flatExprs) {
1179     SmallVector<int64_t> ineq(getNumCols(), 0);
1180     // Dims and symbols.
1181     for (unsigned j = 0, e = boundMap.getNumInputs(); j < e; j++) {
1182       ineq[j] = lower ? -flatExpr[j] : flatExpr[j];
1183     }
1184     // Invalid bound: pos appears in `boundMap`.
1185     // TODO: This should be an assertion. Fix `addDomainFromSliceMaps` and/or
1186     // its callers to prevent invalid bounds from being added.
1187     if (ineq[pos] != 0)
1188       continue;
1189     ineq[pos] = lower ? 1 : -1;
1190     // Local columns of `ineq` are at the beginning.
1191     unsigned j = getNumDimIds() + getNumSymbolIds();
1192     unsigned end = flatExpr.size() - 1;
1193     for (unsigned i = boundMap.getNumInputs(); i < end; i++, j++) {
1194       ineq[j] = lower ? -flatExpr[i] : flatExpr[i];
1195     }
1196     // Constant term.
1197     ineq[getNumCols() - 1] =
1198         lower ? -flatExpr[flatExpr.size() - 1]
1199               // Upper bound in flattenedExpr is an exclusive one.
1200               : flatExpr[flatExpr.size() - 1] - 1;
1201     type == BoundType::EQ ? addEquality(ineq) : addInequality(ineq);
1202   }
1203 
1204   return success();
1205 }
1206 
1207 AffineMap
1208 FlatAffineValueConstraints::computeAlignedMap(AffineMap map,
1209                                               ValueRange operands) const {
1210   assert(map.getNumInputs() == operands.size() && "number of inputs mismatch");
1211 
1212   SmallVector<Value> dims, syms;
1213 #ifndef NDEBUG
1214   SmallVector<Value> newSyms;
1215   SmallVector<Value> *newSymsPtr = &newSyms;
1216 #else
1217   SmallVector<Value> *newSymsPtr = nullptr;
1218 #endif // NDEBUG
1219 
1220   dims.reserve(getNumDimIds());
1221   syms.reserve(getNumSymbolIds());
1222   for (unsigned i = getIdKindOffset(IdKind::SetDim),
1223                 e = getIdKindEnd(IdKind::SetDim);
1224        i < e; ++i)
1225     dims.push_back(values[i] ? *values[i] : Value());
1226   for (unsigned i = getIdKindOffset(IdKind::Symbol),
1227                 e = getIdKindEnd(IdKind::Symbol);
1228        i < e; ++i)
1229     syms.push_back(values[i] ? *values[i] : Value());
1230 
1231   AffineMap alignedMap =
1232       alignAffineMapWithValues(map, operands, dims, syms, newSymsPtr);
1233   // All symbols are already part of this FlatAffineConstraints.
1234   assert(syms.size() == newSymsPtr->size() && "unexpected new/missing symbols");
1235   assert(std::equal(syms.begin(), syms.end(), newSymsPtr->begin()) &&
1236          "unexpected new/missing symbols");
1237   return alignedMap;
1238 }
1239 
1240 LogicalResult FlatAffineValueConstraints::addBound(BoundType type, unsigned pos,
1241                                                    AffineMap boundMap,
1242                                                    ValueRange boundOperands) {
1243   // Fully compose map and operands; canonicalize and simplify so that we
1244   // transitively get to terminal symbols or loop IVs.
1245   auto map = boundMap;
1246   SmallVector<Value, 4> operands(boundOperands.begin(), boundOperands.end());
1247   fullyComposeAffineMapAndOperands(&map, &operands);
1248   map = simplifyAffineMap(map);
1249   canonicalizeMapAndOperands(&map, &operands);
1250   for (auto operand : operands)
1251     addInductionVarOrTerminalSymbol(operand);
1252   return addBound(type, pos, computeAlignedMap(map, operands));
1253 }
1254 
1255 // Adds slice lower bounds represented by lower bounds in 'lbMaps' and upper
1256 // bounds in 'ubMaps' to each value in `values' that appears in the constraint
1257 // system. Note that both lower/upper bounds share the same operand list
1258 // 'operands'.
1259 // This function assumes 'values.size' == 'lbMaps.size' == 'ubMaps.size', and
1260 // skips any null AffineMaps in 'lbMaps' or 'ubMaps'.
1261 // Note that both lower/upper bounds use operands from 'operands'.
1262 // Returns failure for unimplemented cases such as semi-affine expressions or
1263 // expressions with mod/floordiv.
1264 LogicalResult FlatAffineValueConstraints::addSliceBounds(
1265     ArrayRef<Value> values, ArrayRef<AffineMap> lbMaps,
1266     ArrayRef<AffineMap> ubMaps, ArrayRef<Value> operands) {
1267   assert(values.size() == lbMaps.size());
1268   assert(lbMaps.size() == ubMaps.size());
1269 
1270   for (unsigned i = 0, e = lbMaps.size(); i < e; ++i) {
1271     unsigned pos;
1272     if (!findId(values[i], &pos))
1273       continue;
1274 
1275     AffineMap lbMap = lbMaps[i];
1276     AffineMap ubMap = ubMaps[i];
1277     assert(!lbMap || lbMap.getNumInputs() == operands.size());
1278     assert(!ubMap || ubMap.getNumInputs() == operands.size());
1279 
1280     // Check if this slice is just an equality along this dimension.
1281     if (lbMap && ubMap && lbMap.getNumResults() == 1 &&
1282         ubMap.getNumResults() == 1 &&
1283         lbMap.getResult(0) + 1 == ubMap.getResult(0)) {
1284       if (failed(addBound(BoundType::EQ, pos, lbMap, operands)))
1285         return failure();
1286       continue;
1287     }
1288 
1289     // If lower or upper bound maps are null or provide no results, it implies
1290     // that the source loop was not at all sliced, and the entire loop will be a
1291     // part of the slice.
1292     if (lbMap && lbMap.getNumResults() != 0 && ubMap &&
1293         ubMap.getNumResults() != 0) {
1294       if (failed(addBound(BoundType::LB, pos, lbMap, operands)))
1295         return failure();
1296       if (failed(addBound(BoundType::UB, pos, ubMap, operands)))
1297         return failure();
1298     } else {
1299       auto loop = getForInductionVarOwner(values[i]);
1300       if (failed(this->addAffineForOpDomain(loop)))
1301         return failure();
1302     }
1303   }
1304   return success();
1305 }
1306 
1307 bool FlatAffineValueConstraints::findId(Value val, unsigned *pos) const {
1308   unsigned i = 0;
1309   for (const auto &mayBeId : values) {
1310     if (mayBeId.hasValue() && mayBeId.getValue() == val) {
1311       *pos = i;
1312       return true;
1313     }
1314     i++;
1315   }
1316   return false;
1317 }
1318 
1319 bool FlatAffineValueConstraints::containsId(Value val) const {
1320   return llvm::any_of(values, [&](const Optional<Value> &mayBeId) {
1321     return mayBeId.hasValue() && mayBeId.getValue() == val;
1322   });
1323 }
1324 
1325 void FlatAffineValueConstraints::swapId(unsigned posA, unsigned posB) {
1326   IntegerPolyhedron::swapId(posA, posB);
1327   std::swap(values[posA], values[posB]);
1328 }
1329 
1330 void FlatAffineValueConstraints::addBound(BoundType type, Value val,
1331                                           int64_t value) {
1332   unsigned pos;
1333   if (!findId(val, &pos))
1334     // This is a pre-condition for this method.
1335     assert(0 && "id not found");
1336   addBound(type, pos, value);
1337 }
1338 
1339 void FlatAffineValueConstraints::printSpace(raw_ostream &os) const {
1340   IntegerPolyhedron::printSpace(os);
1341   os << "(";
1342   for (unsigned i = 0, e = getNumIds(); i < e; i++) {
1343     if (hasValue(i))
1344       os << "Value ";
1345     else
1346       os << "None ";
1347   }
1348   os << " const)\n";
1349 }
1350 
1351 void FlatAffineValueConstraints::clearAndCopyFrom(
1352     const IntegerRelation &other) {
1353 
1354   if (auto *otherValueSet =
1355           dyn_cast<const FlatAffineValueConstraints>(&other)) {
1356     *this = *otherValueSet;
1357   } else {
1358     *static_cast<IntegerRelation *>(this) = other;
1359     values.clear();
1360     values.resize(getNumIds(), None);
1361   }
1362 }
1363 
1364 void FlatAffineValueConstraints::fourierMotzkinEliminate(
1365     unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
1366   SmallVector<Optional<Value>, 8> newVals;
1367   newVals.reserve(getNumIds() - 1);
1368   newVals.append(values.begin(), values.begin() + pos);
1369   newVals.append(values.begin() + pos + 1, values.end());
1370   // Note: Base implementation discards all associated Values.
1371   IntegerPolyhedron::fourierMotzkinEliminate(pos, darkShadow,
1372                                              isResultIntegerExact);
1373   values = newVals;
1374   assert(values.size() == getNumIds());
1375 }
1376 
1377 void FlatAffineValueConstraints::projectOut(Value val) {
1378   unsigned pos;
1379   bool ret = findId(val, &pos);
1380   assert(ret);
1381   (void)ret;
1382   fourierMotzkinEliminate(pos);
1383 }
1384 
1385 LogicalResult FlatAffineValueConstraints::unionBoundingBox(
1386     const FlatAffineValueConstraints &otherCst) {
1387   assert(otherCst.getNumDimIds() == getNumDimIds() && "dims mismatch");
1388   assert(otherCst.getMaybeValues()
1389              .slice(0, getNumDimIds())
1390              .equals(getMaybeValues().slice(0, getNumDimIds())) &&
1391          "dim values mismatch");
1392   assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
1393   assert(getNumLocalIds() == 0 && "local ids not supported yet here");
1394 
1395   // Align `other` to this.
1396   if (!areIdsAligned(*this, otherCst)) {
1397     FlatAffineValueConstraints otherCopy(otherCst);
1398     mergeAndAlignIds(/*offset=*/getNumDimIds(), this, &otherCopy);
1399     return IntegerPolyhedron::unionBoundingBox(otherCopy);
1400   }
1401 
1402   return IntegerPolyhedron::unionBoundingBox(otherCst);
1403 }
1404 
1405 /// Compute an explicit representation for local vars. For all systems coming
1406 /// from MLIR integer sets, maps, or expressions where local vars were
1407 /// introduced to model floordivs and mods, this always succeeds.
1408 static LogicalResult computeLocalVars(const FlatAffineValueConstraints &cst,
1409                                       SmallVectorImpl<AffineExpr> &memo,
1410                                       MLIRContext *context) {
1411   unsigned numDims = cst.getNumDimIds();
1412   unsigned numSyms = cst.getNumSymbolIds();
1413 
1414   // Initialize dimensional and symbolic identifiers.
1415   for (unsigned i = 0; i < numDims; i++)
1416     memo[i] = getAffineDimExpr(i, context);
1417   for (unsigned i = numDims, e = numDims + numSyms; i < e; i++)
1418     memo[i] = getAffineSymbolExpr(i - numDims, context);
1419 
1420   bool changed;
1421   do {
1422     // Each time `changed` is true at the end of this iteration, one or more
1423     // local vars would have been detected as floordivs and set in memo; so the
1424     // number of null entries in memo[...] strictly reduces; so this converges.
1425     changed = false;
1426     for (unsigned i = 0, e = cst.getNumLocalIds(); i < e; ++i)
1427       if (!memo[numDims + numSyms + i] &&
1428           detectAsFloorDiv(cst, /*pos=*/numDims + numSyms + i, context, memo))
1429         changed = true;
1430   } while (changed);
1431 
1432   ArrayRef<AffineExpr> localExprs =
1433       ArrayRef<AffineExpr>(memo).take_back(cst.getNumLocalIds());
1434   return success(
1435       llvm::all_of(localExprs, [](AffineExpr expr) { return expr; }));
1436 }
1437 
1438 void FlatAffineValueConstraints::getIneqAsAffineValueMap(
1439     unsigned pos, unsigned ineqPos, AffineValueMap &vmap,
1440     MLIRContext *context) const {
1441   unsigned numDims = getNumDimIds();
1442   unsigned numSyms = getNumSymbolIds();
1443 
1444   assert(pos < numDims && "invalid position");
1445   assert(ineqPos < getNumInequalities() && "invalid inequality position");
1446 
1447   // Get expressions for local vars.
1448   SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
1449   if (failed(computeLocalVars(*this, memo, context)))
1450     assert(false &&
1451            "one or more local exprs do not have an explicit representation");
1452   auto localExprs = ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
1453 
1454   // Compute the AffineExpr lower/upper bound for this inequality.
1455   ArrayRef<int64_t> inequality = getInequality(ineqPos);
1456   SmallVector<int64_t, 8> bound;
1457   bound.reserve(getNumCols() - 1);
1458   // Everything other than the coefficient at `pos`.
1459   bound.append(inequality.begin(), inequality.begin() + pos);
1460   bound.append(inequality.begin() + pos + 1, inequality.end());
1461 
1462   if (inequality[pos] > 0)
1463     // Lower bound.
1464     std::transform(bound.begin(), bound.end(), bound.begin(),
1465                    std::negate<int64_t>());
1466   else
1467     // Upper bound (which is exclusive).
1468     bound.back() += 1;
1469 
1470   // Convert to AffineExpr (tree) form.
1471   auto boundExpr = getAffineExprFromFlatForm(bound, numDims - 1, numSyms,
1472                                              localExprs, context);
1473 
1474   // Get the values to bind to this affine expr (all dims and symbols).
1475   SmallVector<Value, 4> operands;
1476   getValues(0, pos, &operands);
1477   SmallVector<Value, 4> trailingOperands;
1478   getValues(pos + 1, getNumDimAndSymbolIds(), &trailingOperands);
1479   operands.append(trailingOperands.begin(), trailingOperands.end());
1480   vmap.reset(AffineMap::get(numDims - 1, numSyms, boundExpr), operands);
1481 }
1482 
1483 IntegerSet
1484 FlatAffineValueConstraints::getAsIntegerSet(MLIRContext *context) const {
1485   if (getNumConstraints() == 0)
1486     // Return universal set (always true): 0 == 0.
1487     return IntegerSet::get(getNumDimIds(), getNumSymbolIds(),
1488                            getAffineConstantExpr(/*constant=*/0, context),
1489                            /*eqFlags=*/true);
1490 
1491   // Construct local references.
1492   SmallVector<AffineExpr, 8> memo(getNumIds(), AffineExpr());
1493 
1494   if (failed(computeLocalVars(*this, memo, context))) {
1495     // Check if the local variables without an explicit representation have
1496     // zero coefficients everywhere.
1497     SmallVector<unsigned> noLocalRepVars;
1498     unsigned numDimsSymbols = getNumDimAndSymbolIds();
1499     for (unsigned i = numDimsSymbols, e = getNumIds(); i < e; ++i) {
1500       if (!memo[i] && !isColZero(/*pos=*/i))
1501         noLocalRepVars.push_back(i - numDimsSymbols);
1502     }
1503     if (!noLocalRepVars.empty()) {
1504       LLVM_DEBUG({
1505         llvm::dbgs() << "local variables at position(s) ";
1506         llvm::interleaveComma(noLocalRepVars, llvm::dbgs());
1507         llvm::dbgs() << " do not have an explicit representation in:\n";
1508         this->dump();
1509       });
1510       return IntegerSet();
1511     }
1512   }
1513 
1514   ArrayRef<AffineExpr> localExprs =
1515       ArrayRef<AffineExpr>(memo).take_back(getNumLocalIds());
1516 
1517   // Construct the IntegerSet from the equalities/inequalities.
1518   unsigned numDims = getNumDimIds();
1519   unsigned numSyms = getNumSymbolIds();
1520 
1521   SmallVector<bool, 16> eqFlags(getNumConstraints());
1522   std::fill(eqFlags.begin(), eqFlags.begin() + getNumEqualities(), true);
1523   std::fill(eqFlags.begin() + getNumEqualities(), eqFlags.end(), false);
1524 
1525   SmallVector<AffineExpr, 8> exprs;
1526   exprs.reserve(getNumConstraints());
1527 
1528   for (unsigned i = 0, e = getNumEqualities(); i < e; ++i)
1529     exprs.push_back(getAffineExprFromFlatForm(getEquality(i), numDims, numSyms,
1530                                               localExprs, context));
1531   for (unsigned i = 0, e = getNumInequalities(); i < e; ++i)
1532     exprs.push_back(getAffineExprFromFlatForm(getInequality(i), numDims,
1533                                               numSyms, localExprs, context));
1534   return IntegerSet::get(numDims, numSyms, exprs, eqFlags);
1535 }
1536 
1537 AffineMap mlir::alignAffineMapWithValues(AffineMap map, ValueRange operands,
1538                                          ValueRange dims, ValueRange syms,
1539                                          SmallVector<Value> *newSyms) {
1540   assert(operands.size() == map.getNumInputs() &&
1541          "expected same number of operands and map inputs");
1542   MLIRContext *ctx = map.getContext();
1543   Builder builder(ctx);
1544   SmallVector<AffineExpr> dimReplacements(map.getNumDims(), {});
1545   unsigned numSymbols = syms.size();
1546   SmallVector<AffineExpr> symReplacements(map.getNumSymbols(), {});
1547   if (newSyms) {
1548     newSyms->clear();
1549     newSyms->append(syms.begin(), syms.end());
1550   }
1551 
1552   for (const auto &operand : llvm::enumerate(operands)) {
1553     // Compute replacement dim/sym of operand.
1554     AffineExpr replacement;
1555     auto dimIt = std::find(dims.begin(), dims.end(), operand.value());
1556     auto symIt = std::find(syms.begin(), syms.end(), operand.value());
1557     if (dimIt != dims.end()) {
1558       replacement =
1559           builder.getAffineDimExpr(std::distance(dims.begin(), dimIt));
1560     } else if (symIt != syms.end()) {
1561       replacement =
1562           builder.getAffineSymbolExpr(std::distance(syms.begin(), symIt));
1563     } else {
1564       // This operand is neither a dimension nor a symbol. Add it as a new
1565       // symbol.
1566       replacement = builder.getAffineSymbolExpr(numSymbols++);
1567       if (newSyms)
1568         newSyms->push_back(operand.value());
1569     }
1570     // Add to corresponding replacements vector.
1571     if (operand.index() < map.getNumDims()) {
1572       dimReplacements[operand.index()] = replacement;
1573     } else {
1574       symReplacements[operand.index() - map.getNumDims()] = replacement;
1575     }
1576   }
1577 
1578   return map.replaceDimsAndSymbols(dimReplacements, symReplacements,
1579                                    dims.size(), numSymbols);
1580 }
1581 
1582 FlatAffineValueConstraints FlatAffineRelation::getDomainSet() const {
1583   FlatAffineValueConstraints domain = *this;
1584   // Convert all range variables to local variables.
1585   domain.convertToLocal(IdKind::SetDim, getNumDomainDims(),
1586                         getNumDomainDims() + getNumRangeDims());
1587   return domain;
1588 }
1589 
1590 FlatAffineValueConstraints FlatAffineRelation::getRangeSet() const {
1591   FlatAffineValueConstraints range = *this;
1592   // Convert all domain variables to local variables.
1593   range.convertToLocal(IdKind::SetDim, 0, getNumDomainDims());
1594   return range;
1595 }
1596 
/// Composes this relation with `other`, applying `other` first. If `other` is
/// [otherDomain] -> [otherRange] and this is [thisDomain] -> [thisRange], the
/// result stored in `*this` is [otherDomain] -> [thisRange].
/// - thisDomain and otherRange must match in count and in attached values
///   (asserted below).
/// - The matched intermediate identifiers are converted to local identifiers.
/// - Redundant locals are removed at the end.
void FlatAffineRelation::compose(const FlatAffineRelation &other) {
  assert(getNumDomainDims() == other.getNumRangeDims() &&
         "Domain of this and range of other do not match");
  assert(std::equal(values.begin(), values.begin() + getNumDomainDims(),
                    other.values.begin() + other.getNumDomainDims()) &&
         "Domain of this and range of other do not match");

  // Work on a copy of `other`; it is assigned back to `*this` at the end.
  FlatAffineRelation rel = other;

  // Convert `rel` from
  //    [otherDomain] -> [otherRange]
  // to
  //    [otherDomain] -> [otherRange thisRange]
  // and `this` from
  //    [thisDomain] -> [thisRange]
  // to
  //    [otherDomain thisDomain] -> [thisRange].
  // `removeDims` is the number of intermediate (otherRange == thisDomain)
  // dimensions that will later be turned into locals.
  unsigned removeDims = rel.getNumRangeDims();
  insertDomainId(0, rel.getNumDomainDims());
  rel.appendRangeId(getNumRangeDims());

  // Merge symbol and local identifiers. This presumably gives both systems
  // the same column layout, which the `rel.append(*this)` below relies on.
  mergeSymbolIds(rel);
  mergeLocalIds(rel);

  // Convert `rel` from [otherDomain] -> [otherRange thisRange] to
  // [otherDomain] -> [thisRange] by converting first otherRange range ids
  // to local ids.
  rel.convertToLocal(IdKind::SetDim, rel.getNumDomainDims(),
                     rel.getNumDomainDims() + removeDims);
  // Convert `this` from [otherDomain thisDomain] -> [thisRange] to
  // [otherDomain] -> [thisRange] by converting last thisDomain domain ids
  // to local ids.
  convertToLocal(IdKind::SetDim, getNumDomainDims() - removeDims,
                 getNumDomainDims());

  // Both systems now have the shape [otherDomain] -> [thisRange]. Copy the
  // attached dimension values across so the two sides agree.
  auto thisMaybeValues = getMaybeDimValues();
  auto relMaybeValues = rel.getMaybeDimValues();

  // Add and match domain of `rel` to domain of `this`.
  for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
    if (relMaybeValues[i].hasValue())
      setValue(i, relMaybeValues[i].getValue());
  // Add and match range of `this` to range of `rel`.
  for (unsigned i = 0, e = getNumRangeDims(); i < e; ++i) {
    unsigned rangeIdx = rel.getNumDomainDims() + i;
    if (thisMaybeValues[rangeIdx].hasValue())
      rel.setValue(rangeIdx, thisMaybeValues[rangeIdx].getValue());
  }

  // Append `this` to `rel` and simplify constraints.
  rel.append(*this);
  rel.removeRedundantLocalVars();

  *this = rel;
}
1653 
1654 void FlatAffineRelation::inverse() {
1655   unsigned oldDomain = getNumDomainDims();
1656   unsigned oldRange = getNumRangeDims();
1657   // Add new range ids.
1658   appendRangeId(oldDomain);
1659   // Swap new ids with domain.
1660   for (unsigned i = 0; i < oldDomain; ++i)
1661     swapId(i, oldDomain + oldRange + i);
1662   // Remove the swapped domain.
1663   removeIdRange(0, oldDomain);
1664   // Set domain and range as inverse.
1665   numDomainDims = oldRange;
1666   numRangeDims = oldDomain;
1667 }
1668 
1669 void FlatAffineRelation::insertDomainId(unsigned pos, unsigned num) {
1670   assert(pos <= getNumDomainDims() &&
1671          "Id cannot be inserted at invalid position");
1672   insertDimId(pos, num);
1673   numDomainDims += num;
1674 }
1675 
1676 void FlatAffineRelation::insertRangeId(unsigned pos, unsigned num) {
1677   assert(pos <= getNumRangeDims() &&
1678          "Id cannot be inserted at invalid position");
1679   insertDimId(getNumDomainDims() + pos, num);
1680   numRangeDims += num;
1681 }
1682 
1683 void FlatAffineRelation::appendDomainId(unsigned num) {
1684   insertDimId(getNumDomainDims(), num);
1685   numDomainDims += num;
1686 }
1687 
1688 void FlatAffineRelation::appendRangeId(unsigned num) {
1689   insertDimId(getNumDimIds(), num);
1690   numRangeDims += num;
1691 }
1692 
1693 void FlatAffineRelation::removeIdRange(IdKind kind, unsigned idStart,
1694                                        unsigned idLimit) {
1695   assert(idLimit <= getNumIdKind(kind));
1696   if (idStart >= idLimit)
1697     return;
1698 
1699   FlatAffineValueConstraints::removeIdRange(kind, idStart, idLimit);
1700 
1701   // If kind is not SetDim, domain and range don't need to be updated.
1702   if (kind != IdKind::SetDim)
1703     return;
1704 
1705   // Compute number of domain and range identifiers to remove. This is done by
1706   // intersecting the range of domain/range ids with range of ids to remove.
1707   unsigned intersectDomainLHS = std::min(idLimit, getNumDomainDims());
1708   unsigned intersectDomainRHS = idStart;
1709   unsigned intersectRangeLHS = std::min(idLimit, getNumDimIds());
1710   unsigned intersectRangeRHS = std::max(idStart, getNumDomainDims());
1711 
1712   if (intersectDomainLHS > intersectDomainRHS)
1713     numDomainDims -= intersectDomainLHS - intersectDomainRHS;
1714   if (intersectRangeLHS > intersectRangeRHS)
1715     numRangeDims -= intersectRangeLHS - intersectRangeRHS;
1716 }
1717 
1718 LogicalResult mlir::getRelationFromMap(AffineMap &map,
1719                                        FlatAffineRelation &rel) {
1720   // Get flattened affine expressions.
1721   std::vector<SmallVector<int64_t, 8>> flatExprs;
1722   FlatAffineValueConstraints localVarCst;
1723   if (failed(getFlattenedAffineExprs(map, &flatExprs, &localVarCst)))
1724     return failure();
1725 
1726   unsigned oldDimNum = localVarCst.getNumDimIds();
1727   unsigned oldCols = localVarCst.getNumCols();
1728   unsigned numRangeIds = map.getNumResults();
1729   unsigned numDomainIds = map.getNumDims();
1730 
1731   // Add range as the new expressions.
1732   localVarCst.appendDimId(numRangeIds);
1733 
1734   // Add equalities between source and range.
1735   SmallVector<int64_t, 8> eq(localVarCst.getNumCols());
1736   for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
1737     // Zero fill.
1738     std::fill(eq.begin(), eq.end(), 0);
1739     // Fill equality.
1740     for (unsigned j = 0, f = oldDimNum; j < f; ++j)
1741       eq[j] = flatExprs[i][j];
1742     for (unsigned j = oldDimNum, f = oldCols; j < f; ++j)
1743       eq[j + numRangeIds] = flatExprs[i][j];
1744     // Set this dimension to -1 to equate lhs and rhs and add equality.
1745     eq[numDomainIds + i] = -1;
1746     localVarCst.addEquality(eq);
1747   }
1748 
1749   // Create relation and return success.
1750   rel = FlatAffineRelation(numDomainIds, numRangeIds, localVarCst);
1751   return success();
1752 }
1753 
1754 LogicalResult mlir::getRelationFromMap(const AffineValueMap &map,
1755                                        FlatAffineRelation &rel) {
1756 
1757   AffineMap affineMap = map.getAffineMap();
1758   if (failed(getRelationFromMap(affineMap, rel)))
1759     return failure();
1760 
1761   // Set symbol values for domain dimensions and symbols.
1762   for (unsigned i = 0, e = rel.getNumDomainDims(); i < e; ++i)
1763     rel.setValue(i, map.getOperand(i));
1764   for (unsigned i = rel.getNumDimIds(), e = rel.getNumDimAndSymbolIds(); i < e;
1765        ++i)
1766     rel.setValue(i, map.getOperand(i - rel.getNumRangeDims()));
1767 
1768   return success();
1769 }
1770