1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/IR/AffineExpr.h"
12 #include "AffineExprDetail.h"
13 #include "mlir/IR/AffineExprVisitor.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/Support/MathExtras.h"
17 #include "mlir/Support/TypeID.h"
18 #include "llvm/ADT/STLExtras.h"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
23 MLIRContext *AffineExpr::getContext() const { return expr->context; }
24 
25 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
26 
27 /// Walk all of the AffineExprs in this subgraph in postorder.
28 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
29   struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
30     std::function<void(AffineExpr)> callback;
31 
32     AffineExprWalker(std::function<void(AffineExpr)> callback)
33         : callback(std::move(callback)) {}
34 
35     void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
36     void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
37     void visitDimExpr(AffineDimExpr expr) { callback(expr); }
38     void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
39   };
40 
41   AffineExprWalker(std::move(callback)).walkPostOrder(*this);
42 }
43 
44 // Dispatch affine expression construction based on kind.
45 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
46                                        AffineExpr rhs) {
47   if (kind == AffineExprKind::Add)
48     return lhs + rhs;
49   if (kind == AffineExprKind::Mul)
50     return lhs * rhs;
51   if (kind == AffineExprKind::FloorDiv)
52     return lhs.floorDiv(rhs);
53   if (kind == AffineExprKind::CeilDiv)
54     return lhs.ceilDiv(rhs);
55   if (kind == AffineExprKind::Mod)
56     return lhs % rhs;
57 
58   llvm_unreachable("unknown binary operation on affine expressions");
59 }
60 
61 /// This method substitutes any uses of dimensions and symbols (e.g.
62 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
63 AffineExpr
64 AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
65                                   ArrayRef<AffineExpr> symReplacements) const {
66   switch (getKind()) {
67   case AffineExprKind::Constant:
68     return *this;
69   case AffineExprKind::DimId: {
70     unsigned dimId = cast<AffineDimExpr>().getPosition();
71     if (dimId >= dimReplacements.size())
72       return *this;
73     return dimReplacements[dimId];
74   }
75   case AffineExprKind::SymbolId: {
76     unsigned symId = cast<AffineSymbolExpr>().getPosition();
77     if (symId >= symReplacements.size())
78       return *this;
79     return symReplacements[symId];
80   }
81   case AffineExprKind::Add:
82   case AffineExprKind::Mul:
83   case AffineExprKind::FloorDiv:
84   case AffineExprKind::CeilDiv:
85   case AffineExprKind::Mod:
86     auto binOp = cast<AffineBinaryOpExpr>();
87     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
88     auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
89     auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
90     if (newLHS == lhs && newRHS == rhs)
91       return *this;
92     return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
93   }
94   llvm_unreachable("Unknown AffineExpr");
95 }
96 
97 AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const {
98   return replaceDimsAndSymbols(dimReplacements, {});
99 }
100 
101 AffineExpr
102 AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const {
103   return replaceDimsAndSymbols({}, symReplacements);
104 }
105 
106 /// Replace dims[offset ... numDims)
107 /// by dims[offset + shift ... shift + numDims).
108 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
109                                  unsigned offset) const {
110   SmallVector<AffineExpr, 4> dims;
111   for (unsigned idx = 0; idx < offset; ++idx)
112     dims.push_back(getAffineDimExpr(idx, getContext()));
113   for (unsigned idx = offset; idx < numDims; ++idx)
114     dims.push_back(getAffineDimExpr(idx + shift, getContext()));
115   return replaceDimsAndSymbols(dims, {});
116 }
117 
118 /// Replace symbols[offset ... numSymbols)
119 /// by symbols[offset + shift ... shift + numSymbols).
120 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
121                                     unsigned offset) const {
122   SmallVector<AffineExpr, 4> symbols;
123   for (unsigned idx = 0; idx < offset; ++idx)
124     symbols.push_back(getAffineSymbolExpr(idx, getContext()));
125   for (unsigned idx = offset; idx < numSymbols; ++idx)
126     symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
127   return replaceDimsAndSymbols({}, symbols);
128 }
129 
130 /// Sparse replace method. Return the modified expression tree.
131 AffineExpr
132 AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
133   auto it = map.find(*this);
134   if (it != map.end())
135     return it->second;
136   switch (getKind()) {
137   default:
138     return *this;
139   case AffineExprKind::Add:
140   case AffineExprKind::Mul:
141   case AffineExprKind::FloorDiv:
142   case AffineExprKind::CeilDiv:
143   case AffineExprKind::Mod:
144     auto binOp = cast<AffineBinaryOpExpr>();
145     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
146     auto newLHS = lhs.replace(map);
147     auto newRHS = rhs.replace(map);
148     if (newLHS == lhs && newRHS == rhs)
149       return *this;
150     return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
151   }
152   llvm_unreachable("Unknown AffineExpr");
153 }
154 
155 /// Sparse replace method. Return the modified expression tree.
156 AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
157   DenseMap<AffineExpr, AffineExpr> map;
158   map.insert(std::make_pair(expr, replacement));
159   return replace(map);
160 }
161 /// Returns true if this expression is made out of only symbols and
162 /// constants (no dimensional identifiers).
163 bool AffineExpr::isSymbolicOrConstant() const {
164   switch (getKind()) {
165   case AffineExprKind::Constant:
166     return true;
167   case AffineExprKind::DimId:
168     return false;
169   case AffineExprKind::SymbolId:
170     return true;
171 
172   case AffineExprKind::Add:
173   case AffineExprKind::Mul:
174   case AffineExprKind::FloorDiv:
175   case AffineExprKind::CeilDiv:
176   case AffineExprKind::Mod: {
177     auto expr = this->cast<AffineBinaryOpExpr>();
178     return expr.getLHS().isSymbolicOrConstant() &&
179            expr.getRHS().isSymbolicOrConstant();
180   }
181   }
182   llvm_unreachable("Unknown AffineExpr");
183 }
184 
185 /// Returns true if this is a pure affine expression, i.e., multiplication,
186 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
187 bool AffineExpr::isPureAffine() const {
188   switch (getKind()) {
189   case AffineExprKind::SymbolId:
190   case AffineExprKind::DimId:
191   case AffineExprKind::Constant:
192     return true;
193   case AffineExprKind::Add: {
194     auto op = cast<AffineBinaryOpExpr>();
195     return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
196   }
197 
198   case AffineExprKind::Mul: {
199     // TODO: Canonicalize the constants in binary operators to the RHS when
200     // possible, allowing this to merge into the next case.
201     auto op = cast<AffineBinaryOpExpr>();
202     return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
203            (op.getLHS().template isa<AffineConstantExpr>() ||
204             op.getRHS().template isa<AffineConstantExpr>());
205   }
206   case AffineExprKind::FloorDiv:
207   case AffineExprKind::CeilDiv:
208   case AffineExprKind::Mod: {
209     auto op = cast<AffineBinaryOpExpr>();
210     return op.getLHS().isPureAffine() &&
211            op.getRHS().template isa<AffineConstantExpr>();
212   }
213   }
214   llvm_unreachable("Unknown AffineExpr");
215 }
216 
217 // Returns the greatest known integral divisor of this affine expression.
218 int64_t AffineExpr::getLargestKnownDivisor() const {
219   AffineBinaryOpExpr binExpr(nullptr);
220   switch (getKind()) {
221   case AffineExprKind::CeilDiv:
222     LLVM_FALLTHROUGH;
223   case AffineExprKind::DimId:
224   case AffineExprKind::FloorDiv:
225   case AffineExprKind::SymbolId:
226     return 1;
227   case AffineExprKind::Constant:
228     return std::abs(this->cast<AffineConstantExpr>().getValue());
229   case AffineExprKind::Mul: {
230     binExpr = this->cast<AffineBinaryOpExpr>();
231     return binExpr.getLHS().getLargestKnownDivisor() *
232            binExpr.getRHS().getLargestKnownDivisor();
233   }
234   case AffineExprKind::Add:
235     LLVM_FALLTHROUGH;
236   case AffineExprKind::Mod: {
237     binExpr = cast<AffineBinaryOpExpr>();
238     return llvm::GreatestCommonDivisor64(
239         binExpr.getLHS().getLargestKnownDivisor(),
240         binExpr.getRHS().getLargestKnownDivisor());
241   }
242   }
243   llvm_unreachable("Unknown AffineExpr");
244 }
245 
246 bool AffineExpr::isMultipleOf(int64_t factor) const {
247   AffineBinaryOpExpr binExpr(nullptr);
248   uint64_t l, u;
249   switch (getKind()) {
250   case AffineExprKind::SymbolId:
251     LLVM_FALLTHROUGH;
252   case AffineExprKind::DimId:
253     return factor * factor == 1;
254   case AffineExprKind::Constant:
255     return cast<AffineConstantExpr>().getValue() % factor == 0;
256   case AffineExprKind::Mul: {
257     binExpr = cast<AffineBinaryOpExpr>();
258     // It's probably not worth optimizing this further (to not traverse the
259     // whole sub-tree under - it that would require a version of isMultipleOf
260     // that on a 'false' return also returns the largest known divisor).
261     return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
262            (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
263            (l * u) % factor == 0;
264   }
265   case AffineExprKind::Add:
266   case AffineExprKind::FloorDiv:
267   case AffineExprKind::CeilDiv:
268   case AffineExprKind::Mod: {
269     binExpr = cast<AffineBinaryOpExpr>();
270     return llvm::GreatestCommonDivisor64(
271                binExpr.getLHS().getLargestKnownDivisor(),
272                binExpr.getRHS().getLargestKnownDivisor()) %
273                factor ==
274            0;
275   }
276   }
277   llvm_unreachable("Unknown AffineExpr");
278 }
279 
280 bool AffineExpr::isFunctionOfDim(unsigned position) const {
281   if (getKind() == AffineExprKind::DimId) {
282     return *this == mlir::getAffineDimExpr(position, getContext());
283   }
284   if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
285     return expr.getLHS().isFunctionOfDim(position) ||
286            expr.getRHS().isFunctionOfDim(position);
287   }
288   return false;
289 }
290 
291 bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
292   if (getKind() == AffineExprKind::SymbolId) {
293     return *this == mlir::getAffineSymbolExpr(position, getContext());
294   }
295   if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
296     return expr.getLHS().isFunctionOfSymbol(position) ||
297            expr.getRHS().isFunctionOfSymbol(position);
298   }
299   return false;
300 }
301 
302 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
303     : AffineExpr(ptr) {}
304 AffineExpr AffineBinaryOpExpr::getLHS() const {
305   return static_cast<ImplType *>(expr)->lhs;
306 }
307 AffineExpr AffineBinaryOpExpr::getRHS() const {
308   return static_cast<ImplType *>(expr)->rhs;
309 }
310 
311 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
312 unsigned AffineDimExpr::getPosition() const {
313   return static_cast<ImplType *>(expr)->position;
314 }
315 
316 /// Returns true if the expression is divisible by the given symbol with
317 /// position `symbolPos`. The argument `opKind` specifies here what kind of
318 /// division or mod operation called this division. It helps in implementing the
319 /// commutative property of the floordiv and ceildiv operations. If the argument
320 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
321 /// operation, then the commutative property can be used otherwise, the floordiv
322 /// operation is not divisible. The same argument holds for ceildiv operation.
323 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
324                                 AffineExprKind opKind) {
325   // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
326   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
327           opKind == AffineExprKind::CeilDiv) &&
328          "unexpected opKind");
329   switch (expr.getKind()) {
330   case AffineExprKind::Constant:
331     if (expr.cast<AffineConstantExpr>().getValue())
332       return false;
333     return true;
334   case AffineExprKind::DimId:
335     return false;
336   case AffineExprKind::SymbolId:
337     return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
338   // Checks divisibility by the given symbol for both operands.
339   case AffineExprKind::Add: {
340     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
341     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
342            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
343   }
344   // Checks divisibility by the given symbol for both operands. Consider the
345   // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
346   // this is a division by s1 and both the operands of modulo are divisible by
347   // s1 but it is not divisible by s1 always. The third argument is
348   // `AffineExprKind::Mod` for this reason.
349   case AffineExprKind::Mod: {
350     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
351     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
352                                AffineExprKind::Mod) &&
353            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
354                                AffineExprKind::Mod);
355   }
356   // Checks if any of the operand divisible by the given symbol.
357   case AffineExprKind::Mul: {
358     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
359     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
360            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
361   }
362   // Floordiv and ceildiv are divisible by the given symbol when the first
363   // operand is divisible, and the affine expression kind of the argument expr
364   // is same as the argument `opKind`. This can be inferred from commutative
365   // property of floordiv and ceildiv operations and are as follow:
366   // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
367   // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
368   // It will fail if operations are not same. For example:
369   // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
370   case AffineExprKind::FloorDiv:
371   case AffineExprKind::CeilDiv: {
372     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
373     if (opKind != expr.getKind())
374       return false;
375     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
376   }
377   }
378   llvm_unreachable("Unknown AffineExpr");
379 }
380 
381 /// Divides the given expression by the given symbol at position `symbolPos`. It
382 /// considers the divisibility condition is checked before calling itself. A
383 /// null expression is returned whenever the divisibility condition fails.
384 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
385                                  AffineExprKind opKind) {
386   // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
387   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
388           opKind == AffineExprKind::CeilDiv) &&
389          "unexpected opKind");
390   switch (expr.getKind()) {
391   case AffineExprKind::Constant:
392     if (expr.cast<AffineConstantExpr>().getValue() != 0)
393       return nullptr;
394     return getAffineConstantExpr(0, expr.getContext());
395   case AffineExprKind::DimId:
396     return nullptr;
397   case AffineExprKind::SymbolId:
398     return getAffineConstantExpr(1, expr.getContext());
399   // Dividing both operands by the given symbol.
400   case AffineExprKind::Add: {
401     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
402     return getAffineBinaryOpExpr(
403         expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
404         symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
405   }
406   // Dividing both operands by the given symbol.
407   case AffineExprKind::Mod: {
408     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
409     return getAffineBinaryOpExpr(
410         expr.getKind(),
411         symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
412         symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
413   }
414   // Dividing any of the operand by the given symbol.
415   case AffineExprKind::Mul: {
416     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
417     if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
418       return binaryExpr.getLHS() *
419              symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
420     return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
421            binaryExpr.getRHS();
422   }
423   // Dividing first operand only by the given symbol.
424   case AffineExprKind::FloorDiv:
425   case AffineExprKind::CeilDiv: {
426     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
427     return getAffineBinaryOpExpr(
428         expr.getKind(),
429         symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
430         binaryExpr.getRHS());
431   }
432   }
433   llvm_unreachable("Unknown AffineExpr");
434 }
435 
436 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
437 /// operations when the second operand simplifies to a symbol and the first
438 /// operand is divisible by that symbol. It can be applied to any semi-affine
439 /// expression. Returned expression can either be a semi-affine or pure affine
440 /// expression.
441 static AffineExpr simplifySemiAffine(AffineExpr expr) {
442   switch (expr.getKind()) {
443   case AffineExprKind::Constant:
444   case AffineExprKind::DimId:
445   case AffineExprKind::SymbolId:
446     return expr;
447   case AffineExprKind::Add:
448   case AffineExprKind::Mul: {
449     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
450     return getAffineBinaryOpExpr(expr.getKind(),
451                                  simplifySemiAffine(binaryExpr.getLHS()),
452                                  simplifySemiAffine(binaryExpr.getRHS()));
453   }
454   // Check if the simplification of the second operand is a symbol, and the
455   // first operand is divisible by it. If the operation is a modulo, a constant
456   // zero expression is returned. In the case of floordiv and ceildiv, the
457   // symbol from the simplification of the second operand divides the first
458   // operand. Otherwise, simplification is not possible.
459   case AffineExprKind::FloorDiv:
460   case AffineExprKind::CeilDiv:
461   case AffineExprKind::Mod: {
462     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
463     AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
464     AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
465     AffineSymbolExpr symbolExpr =
466         simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>();
467     if (!symbolExpr)
468       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
469     unsigned symbolPos = symbolExpr.getPosition();
470     if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
471       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
472     if (expr.getKind() == AffineExprKind::Mod)
473       return getAffineConstantExpr(0, expr.getContext());
474     return symbolicDivide(sLHS, symbolPos, expr.getKind());
475   }
476   }
477   llvm_unreachable("Unknown AffineExpr");
478 }
479 
480 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
481                                        MLIRContext *context) {
482   auto assignCtx = [context](AffineDimExprStorage *storage) {
483     storage->context = context;
484   };
485 
486   StorageUniquer &uniquer = context->getAffineUniquer();
487   return uniquer.get<AffineDimExprStorage>(
488       assignCtx, static_cast<unsigned>(kind), position);
489 }
490 
491 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
492   return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
493 }
494 
495 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
496     : AffineExpr(ptr) {}
497 unsigned AffineSymbolExpr::getPosition() const {
498   return static_cast<ImplType *>(expr)->position;
499 }
500 
501 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
502   return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
503   ;
504 }
505 
506 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
507     : AffineExpr(ptr) {}
508 int64_t AffineConstantExpr::getValue() const {
509   return static_cast<ImplType *>(expr)->constant;
510 }
511 
512 bool AffineExpr::operator==(int64_t v) const {
513   return *this == getAffineConstantExpr(v, getContext());
514 }
515 
516 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
517   auto assignCtx = [context](AffineConstantExprStorage *storage) {
518     storage->context = context;
519   };
520 
521   StorageUniquer &uniquer = context->getAffineUniquer();
522   return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
523 }
524 
525 /// Simplify add expression. Return nullptr if it can't be simplified.
526 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
527   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
528   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
529   // Fold if both LHS, RHS are a constant.
530   if (lhsConst && rhsConst)
531     return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
532                                  lhs.getContext());
533 
534   // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
535   // If only one of them is a symbolic expressions, make it the RHS.
536   if (lhs.isa<AffineConstantExpr>() ||
537       (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
538     return rhs + lhs;
539   }
540 
541   // At this point, if there was a constant, it would be on the right.
542 
543   // Addition with a zero is a noop, return the other input.
544   if (rhsConst) {
545     if (rhsConst.getValue() == 0)
546       return lhs;
547   }
548   // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
549   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
550   if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
551     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
552       return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
553   }
554 
555   // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
556   // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
557   // respective multiplicands.
558   Optional<int64_t> rLhsConst, rRhsConst;
559   AffineExpr firstExpr, secondExpr;
560   AffineConstantExpr rLhsConstExpr;
561   auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
562   if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
563       (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
564     rLhsConst = rLhsConstExpr.getValue();
565     firstExpr = lBinOpExpr.getLHS();
566   } else {
567     rLhsConst = 1;
568     firstExpr = lhs;
569   }
570 
571   auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
572   AffineConstantExpr rRhsConstExpr;
573   if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
574       (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
575     rRhsConst = rRhsConstExpr.getValue();
576     secondExpr = rBinOpExpr.getLHS();
577   } else {
578     rRhsConst = 1;
579     secondExpr = rhs;
580   }
581 
582   if (rLhsConst && rRhsConst && firstExpr == secondExpr)
583     return getAffineBinaryOpExpr(
584         AffineExprKind::Mul, firstExpr,
585         getAffineConstantExpr(rLhsConst.getValue() + rRhsConst.getValue(),
586                               lhs.getContext()));
587 
588   // When doing successive additions, bring constant to the right: turn (d0 + 2)
589   // + d1 into (d0 + d1) + 2.
590   if (lBin && lBin.getKind() == AffineExprKind::Add) {
591     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
592       return lBin.getLHS() + rhs + lrhs;
593     }
594   }
595 
596   // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
597   // q may be a constant or symbolic expression. This leads to a much more
598   // efficient form when 'c' is a power of two, and in general a more compact
599   // and readable form.
600 
601   // Process '(expr floordiv c) * (-c)'.
602   if (!rBinOpExpr)
603     return nullptr;
604 
605   auto lrhs = rBinOpExpr.getLHS();
606   auto rrhs = rBinOpExpr.getRHS();
607 
608   AffineExpr llrhs, rlrhs;
609 
610   // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
611   // symbolic expression.
612   auto lrhsBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
613   // Check rrhsConstOpExpr = -1.
614   auto rrhsConstOpExpr = rrhs.dyn_cast<AffineConstantExpr>();
615   if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
616       lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
617     // Check llrhs = expr floordiv q.
618     llrhs = lrhsBinOpExpr.getLHS();
619     // Check rlrhs = q.
620     rlrhs = lrhsBinOpExpr.getRHS();
621     auto llrhsBinOpExpr = llrhs.dyn_cast<AffineBinaryOpExpr>();
622     if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
623       return nullptr;
624     if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
625       return lhs % rlrhs;
626   }
627 
628   // Process lrhs, which is 'expr floordiv c'.
629   AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
630   if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
631     return nullptr;
632 
633   llrhs = lrBinOpExpr.getLHS();
634   rlrhs = lrBinOpExpr.getRHS();
635 
636   if (lhs == llrhs && rlrhs == -rrhs) {
637     return lhs % rlrhs;
638   }
639   return nullptr;
640 }
641 
642 AffineExpr AffineExpr::operator+(int64_t v) const {
643   return *this + getAffineConstantExpr(v, getContext());
644 }
645 AffineExpr AffineExpr::operator+(AffineExpr other) const {
646   if (auto simplified = simplifyAdd(*this, other))
647     return simplified;
648 
649   StorageUniquer &uniquer = getContext()->getAffineUniquer();
650   return uniquer.get<AffineBinaryOpExprStorage>(
651       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
652 }
653 
654 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
655 static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
656   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
657   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
658 
659   if (lhsConst && rhsConst)
660     return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
661                                  lhs.getContext());
662 
663   assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
664 
665   // Canonicalize the mul expression so that the constant/symbolic term is the
666   // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
667   // constant. (Note that a constant is trivially symbolic).
668   if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
669     // At least one of them has to be symbolic.
670     return rhs * lhs;
671   }
672 
673   // At this point, if there was a constant, it would be on the right.
674 
675   // Multiplication with a one is a noop, return the other input.
676   if (rhsConst) {
677     if (rhsConst.getValue() == 1)
678       return lhs;
679     // Multiplication with zero.
680     if (rhsConst.getValue() == 0)
681       return rhsConst;
682   }
683 
684   // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
685   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
686   if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
687     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
688       return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
689   }
690 
691   // When doing successive multiplication, bring constant to the right: turn (d0
692   // * 2) * d1 into (d0 * d1) * 2.
693   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
694     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
695       return (lBin.getLHS() * rhs) * lrhs;
696     }
697   }
698 
699   return nullptr;
700 }
701 
702 AffineExpr AffineExpr::operator*(int64_t v) const {
703   return *this * getAffineConstantExpr(v, getContext());
704 }
705 AffineExpr AffineExpr::operator*(AffineExpr other) const {
706   if (auto simplified = simplifyMul(*this, other))
707     return simplified;
708 
709   StorageUniquer &uniquer = getContext()->getAffineUniquer();
710   return uniquer.get<AffineBinaryOpExprStorage>(
711       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
712 }
713 
714 // Unary minus, delegate to operator*.
715 AffineExpr AffineExpr::operator-() const {
716   return *this * getAffineConstantExpr(-1, getContext());
717 }
718 
719 // Delegate to operator+.
720 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
721 AffineExpr AffineExpr::operator-(AffineExpr other) const {
722   return *this + (-other);
723 }
724 
725 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
726   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
727   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
728 
729   // mlir floordiv by zero or negative numbers is undefined and preserved as is.
730   if (!rhsConst || rhsConst.getValue() < 1)
731     return nullptr;
732 
733   if (lhsConst)
734     return getAffineConstantExpr(
735         floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
736 
737   // Fold floordiv of a multiply with a constant that is a multiple of the
738   // divisor. Eg: (i * 128) floordiv 64 = i * 2.
739   if (rhsConst == 1)
740     return lhs;
741 
742   // Simplify (expr * const) floordiv divConst when expr is known to be a
743   // multiple of divConst.
744   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
745   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
746     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
747       // rhsConst is known to be a positive constant.
748       if (lrhs.getValue() % rhsConst.getValue() == 0)
749         return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
750     }
751   }
752 
753   // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
754   // known to be a multiple of divConst.
755   if (lBin && lBin.getKind() == AffineExprKind::Add) {
756     int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
757     int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
758     // rhsConst is known to be a positive constant.
759     if (llhsDiv % rhsConst.getValue() == 0 ||
760         lrhsDiv % rhsConst.getValue() == 0)
761       return lBin.getLHS().floorDiv(rhsConst.getValue()) +
762              lBin.getRHS().floorDiv(rhsConst.getValue());
763   }
764 
765   return nullptr;
766 }
767 
768 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
769   return floorDiv(getAffineConstantExpr(v, getContext()));
770 }
771 AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
772   if (auto simplified = simplifyFloorDiv(*this, other))
773     return simplified;
774 
775   StorageUniquer &uniquer = getContext()->getAffineUniquer();
776   return uniquer.get<AffineBinaryOpExprStorage>(
777       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
778       other);
779 }
780 
781 static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
782   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
783   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
784 
785   if (!rhsConst || rhsConst.getValue() < 1)
786     return nullptr;
787 
788   if (lhsConst)
789     return getAffineConstantExpr(
790         ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
791 
792   // Fold ceildiv of a multiply with a constant that is a multiple of the
793   // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
794   if (rhsConst.getValue() == 1)
795     return lhs;
796 
797   // Simplify (expr * const) ceildiv divConst when const is known to be a
798   // multiple of divConst.
799   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
800   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
801     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
802       // rhsConst is known to be a positive constant.
803       if (lrhs.getValue() % rhsConst.getValue() == 0)
804         return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
805     }
806   }
807 
808   return nullptr;
809 }
810 
811 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
812   return ceilDiv(getAffineConstantExpr(v, getContext()));
813 }
814 AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
815   if (auto simplified = simplifyCeilDiv(*this, other))
816     return simplified;
817 
818   StorageUniquer &uniquer = getContext()->getAffineUniquer();
819   return uniquer.get<AffineBinaryOpExprStorage>(
820       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
821       other);
822 }
823 
824 static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
825   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
826   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
827 
828   // mod w.r.t zero or negative numbers is undefined and preserved as is.
829   if (!rhsConst || rhsConst.getValue() < 1)
830     return nullptr;
831 
832   if (lhsConst)
833     return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
834                                  lhs.getContext());
835 
836   // Fold modulo of an expression that is known to be a multiple of a constant
837   // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
838   // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
839   if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
840     return getAffineConstantExpr(0, lhs.getContext());
841 
842   // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
843   // known to be a multiple of divConst.
844   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
845   if (lBin && lBin.getKind() == AffineExprKind::Add) {
846     int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
847     int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
848     // rhsConst is known to be a positive constant.
849     if (llhsDiv % rhsConst.getValue() == 0)
850       return lBin.getRHS() % rhsConst.getValue();
851     if (lrhsDiv % rhsConst.getValue() == 0)
852       return lBin.getLHS() % rhsConst.getValue();
853   }
854 
855   // Simplify (e % a) % b to e % b when b evenly divides a
856   if (lBin && lBin.getKind() == AffineExprKind::Mod) {
857     auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
858     if (intermediate && intermediate.getValue() >= 1 &&
859         mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
860       return lBin.getLHS() % rhsConst.getValue();
861     }
862   }
863 
864   return nullptr;
865 }
866 
867 AffineExpr AffineExpr::operator%(uint64_t v) const {
868   return *this % getAffineConstantExpr(v, getContext());
869 }
870 AffineExpr AffineExpr::operator%(AffineExpr other) const {
871   if (auto simplified = simplifyMod(*this, other))
872     return simplified;
873 
874   StorageUniquer &uniquer = getContext()->getAffineUniquer();
875   return uniquer.get<AffineBinaryOpExprStorage>(
876       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
877 }
878 
879 AffineExpr AffineExpr::compose(AffineMap map) const {
880   SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
881                                              map.getResults().end());
882   return replaceDimsAndSymbols(dimReplacements, {});
883 }
884 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
885   expr.print(os);
886   return os;
887 }
888 
889 /// Constructs an affine expression from a flat ArrayRef. If there are local
890 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
891 /// products expression, `localExprs` is expected to have the AffineExpr
892 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
893 /// in the format [dims, symbols, locals, constant term].
894 AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
895                                            unsigned numDims,
896                                            unsigned numSymbols,
897                                            ArrayRef<AffineExpr> localExprs,
898                                            MLIRContext *context) {
899   // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
900   assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
901          "unexpected number of local expressions");
902 
903   auto expr = getAffineConstantExpr(0, context);
904   // Dimensions and symbols.
905   for (unsigned j = 0; j < numDims + numSymbols; j++) {
906     if (flatExprs[j] == 0)
907       continue;
908     auto id = j < numDims ? getAffineDimExpr(j, context)
909                           : getAffineSymbolExpr(j - numDims, context);
910     expr = expr + id * flatExprs[j];
911   }
912 
913   // Local identifiers.
914   for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
915        j++) {
916     if (flatExprs[j] == 0)
917       continue;
918     auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
919     expr = expr + term;
920   }
921 
922   // Constant term.
923   int64_t constTerm = flatExprs[flatExprs.size() - 1];
924   if (constTerm != 0)
925     expr = expr + constTerm;
926   return expr;
927 }
928 
929 /// Constructs a semi-affine expression from a flat ArrayRef. If there are
930 /// local identifiers (neither dimensional nor symbolic) that appear in the sum
931 /// of products expression, `localExprs` is expected to have the AffineExprs for
932 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
933 /// the format [dims, symbols, locals, constant term]. The semi-affine
934 /// expression is constructed in the sorted order of dimension and symbol
935 /// position numbers. Note:  local expressions/ids are used for mod, div as well
936 /// as symbolic RHS terms for terms that are not pure affine.
937 static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
938                                                 unsigned numDims,
939                                                 unsigned numSymbols,
940                                                 ArrayRef<AffineExpr> localExprs,
941                                                 MLIRContext *context) {
942   assert(!flatExprs.empty() && "flatExprs cannot be empty");
943 
944   // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
945   assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
946          "unexpected number of local expressions");
947 
948   AffineExpr expr = getAffineConstantExpr(0, context);
949 
950   // We design indices as a pair which help us present the semi-affine map as
951   // sum of product where terms are sorted based on dimension or symbol
952   // position: <keyA, keyB> for expressions of the form dimension * symbol,
953   // where keyA is the position number of the dimension and keyB is the
954   // position number of the symbol. For dimensional expressions we set the index
955   // as (position number of the dimension, -1), as we want dimensional
956   // expressions to appear before symbolic and product of dimensional and
957   // symbolic expressions having the dimension with the same position number.
958   // For symbolic expression set the index as (position number of the symbol,
959   // maximum of last dimension and symbol position) number. For example, we want
960   // the expression we are constructing to look something like: d0 + d0 * s0 +
961   // s0 + d1*s1 + s1.
962 
963   // Stores the affine expression corresponding to a given index.
964   DenseMap<std::pair<unsigned, signed>, AffineExpr> indexToExprMap;
965   // Stores the constant coefficient value corresponding to a given
966   // dimension, symbol or a non-pure affine expression stored in `localExprs`.
967   DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
968   // Stores the indices as defined above, and later sorted to produce
969   // the semi-affine expression in the desired form.
970   SmallVector<std::pair<unsigned, signed>, 8> indices;
971 
972   // Example: expression = d0 + d0 * s0 + 2 * s0.
973   // indices = [{0,-1}, {0, 0}, {0, 1}]
974   // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
975   // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
976 
977   // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
978   auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
979                       AffineExpr expr) {
980     assert(std::find(indices.begin(), indices.end(), index) == indices.end() &&
981            "Key is already present in indices vector and overwriting will "
982            "happen in `indexToExprMap` and `coefficients`!");
983 
984     indices.push_back(index);
985     coefficients.insert({index, coefficient});
986     indexToExprMap.insert({index, expr});
987   };
988 
989   // Design indices for dimensional or symbolic terms, and store the indices,
990   // constant coefficient corresponding to the indices in `coefficients` map,
991   // and affine expression corresponding to indices in `indexToExprMap` map.
992 
993   for (unsigned j = 0; j < numDims; ++j) {
994     if (flatExprs[j] == 0)
995       continue;
996     // For dimensional expressions we set the index as <position number of the
997     // dimension, 0>, as we want dimensional expressions to appear before
998     // symbolic ones and products of dimensional and symbolic expressions
999     // having the dimension with the same position number.
1000     std::pair<unsigned, signed> indexEntry(j, -1);
1001     addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
1002   }
1003   for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1004     if (flatExprs[j] == 0)
1005       continue;
1006     // For symbolic expression set the index as <position number
1007     // of the symbol, max(dimCount, symCount)> number,
1008     // as we want symbolic expressions with the same positional number to
1009     // appear after dimensional expressions having the same positional number.
1010     std::pair<unsigned, signed> indexEntry(j - numDims,
1011                                            std::max(numDims, numSymbols));
1012     addEntry(indexEntry, flatExprs[j],
1013              getAffineSymbolExpr(j - numDims, context));
1014   }
1015 
1016   // Denotes semi-affine product, modulo or division terms, which has been added
1017   // to the `indexToExpr` map.
1018   SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1019                                   false);
1020   unsigned lhsPos, rhsPos;
1021   // Construct indices for product terms involving dimension, symbol or constant
1022   // as lhs/rhs, and store the indices, constant coefficient corresponding to
1023   // the indices in `coefficients` map, and affine expression corresponding to
1024   // in indices in `indexToExprMap` map.
1025   for (const auto &it : llvm::enumerate(localExprs)) {
1026     AffineExpr expr = it.value();
1027     if (flatExprs[numDims + numSymbols + it.index()] == 0)
1028       continue;
1029     AffineExpr lhs = expr.cast<AffineBinaryOpExpr>().getLHS();
1030     AffineExpr rhs = expr.cast<AffineBinaryOpExpr>().getRHS();
1031     if (!((lhs.isa<AffineDimExpr>() || lhs.isa<AffineSymbolExpr>()) &&
1032           (rhs.isa<AffineDimExpr>() || rhs.isa<AffineSymbolExpr>() ||
1033            rhs.isa<AffineConstantExpr>()))) {
1034       continue;
1035     }
1036     if (rhs.isa<AffineConstantExpr>()) {
1037       // For product/modulo/division expressions, when rhs of modulo/division
1038       // expression is constant, we put 0 in place of keyB, because we want
1039       // them to appear earlier in the semi-affine expression we are
1040       // constructing. When rhs is constant, we place 0 in place of keyB.
1041       if (lhs.isa<AffineDimExpr>()) {
1042         lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1043         std::pair<unsigned, signed> indexEntry(lhsPos, -1);
1044         addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1045                  expr);
1046       } else {
1047         lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1048         std::pair<unsigned, signed> indexEntry(lhsPos,
1049                                                std::max(numDims, numSymbols));
1050         addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1051                  expr);
1052       }
1053     } else if (lhs.isa<AffineDimExpr>()) {
1054       // For product/modulo/division expressions having lhs as dimension and rhs
1055       // as symbol, we order the terms in the semi-affine expression based on
1056       // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1057       // where keyA is the position number of the dimension and keyB is the
1058       // position number of the symbol.
1059       lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1060       rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1061       std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1062       addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1063     } else {
1064       // For product/modulo/division expressions having both lhs and rhs as
1065       // symbol, we design indices as a pair: <keyA, keyB> for expressions
1066       // of the form dimension * symbol, where keyA is the position number of
1067       // the dimension and keyB is the position number of the symbol.
1068       lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1069       rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1070       std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1071       addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1072     }
1073     addedToMap[it.index()] = true;
1074   }
1075 
1076   // Constructing the simplified semi-affine sum of product/division/mod
1077   // expression from the flattened form in the desired sorted order of indices
1078   // of the various individual product/division/mod expressions.
1079   std::sort(indices.begin(), indices.end());
1080   for (const std::pair<unsigned, unsigned> index : indices) {
1081     assert(indexToExprMap.lookup(index) &&
1082            "cannot find key in `indexToExprMap` map");
1083     expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1084   }
1085 
1086   // Local identifiers.
1087   for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1088        j++) {
1089     // If the coefficient of the local expression is 0, continue as we need not
1090     // add it in out final expression.
1091     if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1092       continue;
1093     auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1094     expr = expr + term;
1095   }
1096 
1097   // Constant term.
1098   int64_t constTerm = flatExprs.back();
1099   if (constTerm != 0)
1100     expr = expr + constTerm;
1101   return expr;
1102 }
1103 
1104 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
1105                                                      unsigned numSymbols)
1106     : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1107   operandExprStack.reserve(8);
1108 }
1109 
1110 // In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1111 //
1112 // In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1113 // introduce a local variable p (= expr * symbolic_expr), and the affine
1114 // expression expr * symbolic_expr is added to `localExprs`.
1115 void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
1116   assert(operandExprStack.size() >= 2);
1117   SmallVector<int64_t, 8> rhs = operandExprStack.back();
1118   operandExprStack.pop_back();
1119   SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1120 
1121   // Flatten semi-affine multiplication expressions by introducing a local
1122   // variable in place of the product; the affine expression
1123   // corresponding to the quantifier is added to `localExprs`.
1124   if (!expr.getRHS().isa<AffineConstantExpr>()) {
1125     MLIRContext *context = expr.getContext();
1126     AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
1127                                              localExprs, context);
1128     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1129                                              localExprs, context);
1130     addLocalVariableSemiAffine(a * b, lhs, lhs.size());
1131     return;
1132   }
1133 
1134   // Get the RHS constant.
1135   auto rhsConst = rhs[getConstantIndex()];
1136   for (unsigned i = 0, e = lhs.size(); i < e; i++) {
1137     lhs[i] *= rhsConst;
1138   }
1139 }
1140 
1141 void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
1142   assert(operandExprStack.size() >= 2);
1143   const auto &rhs = operandExprStack.back();
1144   auto &lhs = operandExprStack[operandExprStack.size() - 2];
1145   assert(lhs.size() == rhs.size());
1146   // Update the LHS in place.
1147   for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1148     lhs[i] += rhs[i];
1149   }
1150   // Pop off the RHS.
1151   operandExprStack.pop_back();
1152 }
1153 
1154 //
1155 // t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
1156 //
1157 // A mod expression "expr mod c" is thus flattened by introducing a new local
1158 // variable q (= expr floordiv c), such that expr mod c is replaced with
1159 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1160 //
1161 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1162 // introduce a local variable m (= expr mod symbolic_expr), and the affine
1163 // expression expr mod symbolic_expr is added to `localExprs`.
1164 void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
1165   assert(operandExprStack.size() >= 2);
1166 
1167   SmallVector<int64_t, 8> rhs = operandExprStack.back();
1168   operandExprStack.pop_back();
1169   SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1170   MLIRContext *context = expr.getContext();
1171 
1172   // Flatten semi affine modulo expressions by introducing a local
1173   // variable in place of the modulo value, and the affine expression
1174   // corresponding to the quantifier is added to `localExprs`.
1175   if (!expr.getRHS().isa<AffineConstantExpr>()) {
1176     AffineExpr dividendExpr = getAffineExprFromFlatForm(
1177         lhs, numDims, numSymbols, localExprs, context);
1178     AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1179                                                        localExprs, context);
1180     AffineExpr modExpr = dividendExpr % divisorExpr;
1181     addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
1182     return;
1183   }
1184 
1185   int64_t rhsConst = rhs[getConstantIndex()];
1186   // TODO: handle modulo by zero case when this issue is fixed
1187   // at the other places in the IR.
1188   assert(rhsConst > 0 && "RHS constant has to be positive");
1189 
1190   // Check if the LHS expression is a multiple of modulo factor.
1191   unsigned i, e;
1192   for (i = 0, e = lhs.size(); i < e; i++)
1193     if (lhs[i] % rhsConst != 0)
1194       break;
1195   // If yes, modulo expression here simplifies to zero.
1196   if (i == lhs.size()) {
1197     std::fill(lhs.begin(), lhs.end(), 0);
1198     return;
1199   }
1200 
1201   // Add a local variable for the quotient, i.e., expr % c is replaced by
1202   // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1203   // the GCD of expr and c.
1204   SmallVector<int64_t, 8> floorDividend(lhs);
1205   uint64_t gcd = rhsConst;
1206   for (unsigned i = 0, e = lhs.size(); i < e; i++)
1207     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1208   // Simplify the numerator and the denominator.
1209   if (gcd != 1) {
1210     for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
1211       floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
1212   }
1213   int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1214 
1215   // Construct the AffineExpr form of the floordiv to store in localExprs.
1216 
1217   AffineExpr dividendExpr = getAffineExprFromFlatForm(
1218       floorDividend, numDims, numSymbols, localExprs, context);
1219   AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
1220   AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
1221   int loc;
1222   if ((loc = findLocalId(floorDivExpr)) == -1) {
1223     addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
1224     // Set result at top of stack to "lhs - rhsConst * q".
1225     lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1226   } else {
1227     // Reuse the existing local id.
1228     lhs[getLocalVarStartIndex() + loc] = -rhsConst;
1229   }
1230 }
1231 
1232 void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
1233   visitDivExpr(expr, /*isCeil=*/true);
1234 }
1235 void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
1236   visitDivExpr(expr, /*isCeil=*/false);
1237 }
1238 
1239 void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
1240   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1241   auto &eq = operandExprStack.back();
1242   assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1243   eq[getDimStartIndex() + expr.getPosition()] = 1;
1244 }
1245 
1246 void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
1247   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1248   auto &eq = operandExprStack.back();
1249   assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1250   eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1251 }
1252 
1253 void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
1254   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1255   auto &eq = operandExprStack.back();
1256   eq[getConstantIndex()] = expr.getValue();
1257 }
1258 
1259 void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1260     AffineExpr expr, SmallVectorImpl<int64_t> &result,
1261     unsigned long resultSize) {
1262   assert(result.size() == resultSize &&
1263          "`result` vector passed is not of correct size");
1264   int loc;
1265   if ((loc = findLocalId(expr)) == -1)
1266     addLocalIdSemiAffine(expr);
1267   std::fill(result.begin(), result.end(), 0);
1268   if (loc == -1)
1269     result[getLocalVarStartIndex() + numLocals - 1] = 1;
1270   else
1271     result[getLocalVarStartIndex() + loc] = 1;
1272 }
1273 
1274 // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
1275 // A floordiv is thus flattened by introducing a new local variable q, and
1276 // replacing that expression with 'q' while adding the constraints
1277 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
1278 // FlatAffineConstraints::addLocalFloorDiv).
1279 //
1280 // A ceildiv is similarly flattened:
1281 // t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
1282 //
1283 // In case of semi affine division expressions, t = expr floordiv symbolic_expr
1284 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1285 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1286 // `localExprs`.
1287 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1288                                              bool isCeil) {
1289   assert(operandExprStack.size() >= 2);
1290 
1291   MLIRContext *context = expr.getContext();
1292   SmallVector<int64_t, 8> rhs = operandExprStack.back();
1293   operandExprStack.pop_back();
1294   SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1295 
1296   // Flatten semi affine division expressions by introducing a local
1297   // variable in place of the quotient, and the affine expression corresponding
1298   // to the quantifier is added to `localExprs`.
1299   if (!expr.getRHS().isa<AffineConstantExpr>()) {
1300     AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
1301                                              localExprs, context);
1302     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1303                                              localExprs, context);
1304     AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1305     addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
1306     return;
1307   }
1308 
1309   // This is a pure affine expr; the RHS is a positive constant.
1310   int64_t rhsConst = rhs[getConstantIndex()];
1311   // TODO: handle division by zero at the same time the issue is
1312   // fixed at other places.
1313   assert(rhsConst > 0 && "RHS constant has to be positive");
1314 
1315   // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1316   // common divisors of the numerator and denominator.
1317   uint64_t gcd = std::abs(rhsConst);
1318   for (unsigned i = 0, e = lhs.size(); i < e; i++)
1319     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1320   // Simplify the numerator and the denominator.
1321   if (gcd != 1) {
1322     for (unsigned i = 0, e = lhs.size(); i < e; i++)
1323       lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
1324   }
1325   int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1326   // If the divisor becomes 1, the updated LHS is the result. (The
1327   // divisor can't be negative since rhsConst is positive).
1328   if (divisor == 1)
1329     return;
1330 
1331   // If the divisor cannot be simplified to one, we will have to retain
1332   // the ceil/floor expr (simplified up until here). Add an existential
1333   // quantifier to express its result, i.e., expr1 div expr2 is replaced
1334   // by a new identifier, q.
1335   AffineExpr a =
1336       getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context);
1337   AffineExpr b = getAffineConstantExpr(divisor, context);
1338 
1339   int loc;
1340   AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1341   if ((loc = findLocalId(divExpr)) == -1) {
1342     if (!isCeil) {
1343       SmallVector<int64_t, 8> dividend(lhs);
1344       addLocalFloorDivId(dividend, divisor, divExpr);
1345     } else {
1346       // lhs ceildiv c <=>  (lhs + c - 1) floordiv c
1347       SmallVector<int64_t, 8> dividend(lhs);
1348       dividend.back() += divisor - 1;
1349       addLocalFloorDivId(dividend, divisor, divExpr);
1350     }
1351   }
1352   // Set the expression on stack to the local var introduced to capture the
1353   // result of the division (floor or ceil).
1354   std::fill(lhs.begin(), lhs.end(), 0);
1355   if (loc == -1)
1356     lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1357   else
1358     lhs[getLocalVarStartIndex() + loc] = 1;
1359 }
1360 
1361 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1362 // The local identifier added is always a floordiv of a pure add/mul affine
1363 // function of other identifiers, coefficients of which are specified in
1364 // dividend and with respect to a positive constant divisor. localExpr is the
1365 // simplified tree expression (AffineExpr) corresponding to the quantifier.
1366 void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
1367                                                    int64_t divisor,
1368                                                    AffineExpr localExpr) {
1369   assert(divisor > 0 && "positive constant divisor expected");
1370   for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1371     subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1372   localExprs.push_back(localExpr);
1373   numLocals++;
1374   // dividend and divisor are not used here; an override of this method uses it.
1375 }
1376 
1377 void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr) {
1378   for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1379     subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1380   localExprs.push_back(localExpr);
1381   ++numLocals;
1382 }
1383 
1384 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1385   SmallVectorImpl<AffineExpr>::iterator it;
1386   if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1387     return -1;
1388   return it - localExprs.begin();
1389 }
1390 
1391 /// Simplify the affine expression by flattening it and reconstructing it.
1392 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
1393                                     unsigned numSymbols) {
1394   // Simplify semi-affine expressions separately.
1395   if (!expr.isPureAffine())
1396     expr = simplifySemiAffine(expr);
1397 
1398   SimpleAffineExprFlattener flattener(numDims, numSymbols);
1399   flattener.walkPostOrder(expr);
1400   ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1401   if (!expr.isPureAffine() &&
1402       expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1403                                         flattener.localExprs,
1404                                         expr.getContext()))
1405     return expr;
1406   AffineExpr simplifiedExpr =
1407       expr.isPureAffine()
1408           ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1409                                       flattener.localExprs, expr.getContext())
1410           : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1411                                           flattener.localExprs,
1412                                           expr.getContext());
1413 
1414   flattener.operandExprStack.pop_back();
1415   assert(flattener.operandExprStack.empty());
1416   return simplifiedExpr;
1417 }
1418