1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "mlir/IR/AffineExpr.h"
12 #include "AffineExprDetail.h"
13 #include "mlir/IR/AffineExprVisitor.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/Support/MathExtras.h"
17 #include "mlir/Support/TypeID.h"
18 #include "llvm/ADT/STLExtras.h"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
23 MLIRContext *AffineExpr::getContext() const { return expr->context; }
24 
25 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
26 
27 /// Walk all of the AffineExprs in this subgraph in postorder.
28 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
29   struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
30     std::function<void(AffineExpr)> callback;
31 
32     AffineExprWalker(std::function<void(AffineExpr)> callback)
33         : callback(std::move(callback)) {}
34 
35     void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
36     void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
37     void visitDimExpr(AffineDimExpr expr) { callback(expr); }
38     void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
39   };
40 
41   AffineExprWalker(std::move(callback)).walkPostOrder(*this);
42 }
43 
44 // Dispatch affine expression construction based on kind.
45 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
46                                        AffineExpr rhs) {
47   if (kind == AffineExprKind::Add)
48     return lhs + rhs;
49   if (kind == AffineExprKind::Mul)
50     return lhs * rhs;
51   if (kind == AffineExprKind::FloorDiv)
52     return lhs.floorDiv(rhs);
53   if (kind == AffineExprKind::CeilDiv)
54     return lhs.ceilDiv(rhs);
55   if (kind == AffineExprKind::Mod)
56     return lhs % rhs;
57 
58   llvm_unreachable("unknown binary operation on affine expressions");
59 }
60 
61 /// This method substitutes any uses of dimensions and symbols (e.g.
62 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
63 AffineExpr
64 AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
65                                   ArrayRef<AffineExpr> symReplacements) const {
66   switch (getKind()) {
67   case AffineExprKind::Constant:
68     return *this;
69   case AffineExprKind::DimId: {
70     unsigned dimId = cast<AffineDimExpr>().getPosition();
71     if (dimId >= dimReplacements.size())
72       return *this;
73     return dimReplacements[dimId];
74   }
75   case AffineExprKind::SymbolId: {
76     unsigned symId = cast<AffineSymbolExpr>().getPosition();
77     if (symId >= symReplacements.size())
78       return *this;
79     return symReplacements[symId];
80   }
81   case AffineExprKind::Add:
82   case AffineExprKind::Mul:
83   case AffineExprKind::FloorDiv:
84   case AffineExprKind::CeilDiv:
85   case AffineExprKind::Mod:
86     auto binOp = cast<AffineBinaryOpExpr>();
87     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
88     auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
89     auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
90     if (newLHS == lhs && newRHS == rhs)
91       return *this;
92     return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
93   }
94   llvm_unreachable("Unknown AffineExpr");
95 }
96 
97 AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const {
98   return replaceDimsAndSymbols(dimReplacements, {});
99 }
100 
101 AffineExpr
102 AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const {
103   return replaceDimsAndSymbols({}, symReplacements);
104 }
105 
106 /// Replace dims[offset ... numDims)
107 /// by dims[offset + shift ... shift + numDims).
108 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
109                                  unsigned offset) const {
110   SmallVector<AffineExpr, 4> dims;
111   for (unsigned idx = 0; idx < offset; ++idx)
112     dims.push_back(getAffineDimExpr(idx, getContext()));
113   for (unsigned idx = offset; idx < numDims; ++idx)
114     dims.push_back(getAffineDimExpr(idx + shift, getContext()));
115   return replaceDimsAndSymbols(dims, {});
116 }
117 
118 /// Replace symbols[offset ... numSymbols)
119 /// by symbols[offset + shift ... shift + numSymbols).
120 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
121                                     unsigned offset) const {
122   SmallVector<AffineExpr, 4> symbols;
123   for (unsigned idx = 0; idx < offset; ++idx)
124     symbols.push_back(getAffineSymbolExpr(idx, getContext()));
125   for (unsigned idx = offset; idx < numSymbols; ++idx)
126     symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
127   return replaceDimsAndSymbols({}, symbols);
128 }
129 
130 /// Sparse replace method. Return the modified expression tree.
131 AffineExpr
132 AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
133   auto it = map.find(*this);
134   if (it != map.end())
135     return it->second;
136   switch (getKind()) {
137   default:
138     return *this;
139   case AffineExprKind::Add:
140   case AffineExprKind::Mul:
141   case AffineExprKind::FloorDiv:
142   case AffineExprKind::CeilDiv:
143   case AffineExprKind::Mod:
144     auto binOp = cast<AffineBinaryOpExpr>();
145     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
146     auto newLHS = lhs.replace(map);
147     auto newRHS = rhs.replace(map);
148     if (newLHS == lhs && newRHS == rhs)
149       return *this;
150     return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
151   }
152   llvm_unreachable("Unknown AffineExpr");
153 }
154 
155 /// Sparse replace method. Return the modified expression tree.
156 AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
157   DenseMap<AffineExpr, AffineExpr> map;
158   map.insert(std::make_pair(expr, replacement));
159   return replace(map);
160 }
161 /// Returns true if this expression is made out of only symbols and
162 /// constants (no dimensional identifiers).
163 bool AffineExpr::isSymbolicOrConstant() const {
164   switch (getKind()) {
165   case AffineExprKind::Constant:
166     return true;
167   case AffineExprKind::DimId:
168     return false;
169   case AffineExprKind::SymbolId:
170     return true;
171 
172   case AffineExprKind::Add:
173   case AffineExprKind::Mul:
174   case AffineExprKind::FloorDiv:
175   case AffineExprKind::CeilDiv:
176   case AffineExprKind::Mod: {
177     auto expr = this->cast<AffineBinaryOpExpr>();
178     return expr.getLHS().isSymbolicOrConstant() &&
179            expr.getRHS().isSymbolicOrConstant();
180   }
181   }
182   llvm_unreachable("Unknown AffineExpr");
183 }
184 
185 /// Returns true if this is a pure affine expression, i.e., multiplication,
186 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
187 bool AffineExpr::isPureAffine() const {
188   switch (getKind()) {
189   case AffineExprKind::SymbolId:
190   case AffineExprKind::DimId:
191   case AffineExprKind::Constant:
192     return true;
193   case AffineExprKind::Add: {
194     auto op = cast<AffineBinaryOpExpr>();
195     return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
196   }
197 
198   case AffineExprKind::Mul: {
199     // TODO: Canonicalize the constants in binary operators to the RHS when
200     // possible, allowing this to merge into the next case.
201     auto op = cast<AffineBinaryOpExpr>();
202     return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
203            (op.getLHS().template isa<AffineConstantExpr>() ||
204             op.getRHS().template isa<AffineConstantExpr>());
205   }
206   case AffineExprKind::FloorDiv:
207   case AffineExprKind::CeilDiv:
208   case AffineExprKind::Mod: {
209     auto op = cast<AffineBinaryOpExpr>();
210     return op.getLHS().isPureAffine() &&
211            op.getRHS().template isa<AffineConstantExpr>();
212   }
213   }
214   llvm_unreachable("Unknown AffineExpr");
215 }
216 
217 // Returns the greatest known integral divisor of this affine expression.
218 int64_t AffineExpr::getLargestKnownDivisor() const {
219   AffineBinaryOpExpr binExpr(nullptr);
220   switch (getKind()) {
221   case AffineExprKind::CeilDiv:
222     LLVM_FALLTHROUGH;
223   case AffineExprKind::DimId:
224   case AffineExprKind::FloorDiv:
225   case AffineExprKind::SymbolId:
226     return 1;
227   case AffineExprKind::Constant:
228     return std::abs(this->cast<AffineConstantExpr>().getValue());
229   case AffineExprKind::Mul: {
230     binExpr = this->cast<AffineBinaryOpExpr>();
231     return binExpr.getLHS().getLargestKnownDivisor() *
232            binExpr.getRHS().getLargestKnownDivisor();
233   }
234   case AffineExprKind::Add:
235     LLVM_FALLTHROUGH;
236   case AffineExprKind::Mod: {
237     binExpr = cast<AffineBinaryOpExpr>();
238     return llvm::GreatestCommonDivisor64(
239         binExpr.getLHS().getLargestKnownDivisor(),
240         binExpr.getRHS().getLargestKnownDivisor());
241   }
242   }
243   llvm_unreachable("Unknown AffineExpr");
244 }
245 
246 bool AffineExpr::isMultipleOf(int64_t factor) const {
247   AffineBinaryOpExpr binExpr(nullptr);
248   uint64_t l, u;
249   switch (getKind()) {
250   case AffineExprKind::SymbolId:
251     LLVM_FALLTHROUGH;
252   case AffineExprKind::DimId:
253     return factor * factor == 1;
254   case AffineExprKind::Constant:
255     return cast<AffineConstantExpr>().getValue() % factor == 0;
256   case AffineExprKind::Mul: {
257     binExpr = cast<AffineBinaryOpExpr>();
258     // It's probably not worth optimizing this further (to not traverse the
259     // whole sub-tree under - it that would require a version of isMultipleOf
260     // that on a 'false' return also returns the largest known divisor).
261     return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
262            (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
263            (l * u) % factor == 0;
264   }
265   case AffineExprKind::Add:
266   case AffineExprKind::FloorDiv:
267   case AffineExprKind::CeilDiv:
268   case AffineExprKind::Mod: {
269     binExpr = cast<AffineBinaryOpExpr>();
270     return llvm::GreatestCommonDivisor64(
271                binExpr.getLHS().getLargestKnownDivisor(),
272                binExpr.getRHS().getLargestKnownDivisor()) %
273                factor ==
274            0;
275   }
276   }
277   llvm_unreachable("Unknown AffineExpr");
278 }
279 
280 bool AffineExpr::isFunctionOfDim(unsigned position) const {
281   if (getKind() == AffineExprKind::DimId) {
282     return *this == mlir::getAffineDimExpr(position, getContext());
283   }
284   if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
285     return expr.getLHS().isFunctionOfDim(position) ||
286            expr.getRHS().isFunctionOfDim(position);
287   }
288   return false;
289 }
290 
291 bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
292   if (getKind() == AffineExprKind::SymbolId) {
293     return *this == mlir::getAffineSymbolExpr(position, getContext());
294   }
295   if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
296     return expr.getLHS().isFunctionOfSymbol(position) ||
297            expr.getRHS().isFunctionOfSymbol(position);
298   }
299   return false;
300 }
301 
302 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
303     : AffineExpr(ptr) {}
304 AffineExpr AffineBinaryOpExpr::getLHS() const {
305   return static_cast<ImplType *>(expr)->lhs;
306 }
307 AffineExpr AffineBinaryOpExpr::getRHS() const {
308   return static_cast<ImplType *>(expr)->rhs;
309 }
310 
311 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
312 unsigned AffineDimExpr::getPosition() const {
313   return static_cast<ImplType *>(expr)->position;
314 }
315 
316 /// Returns true if the expression is divisible by the given symbol with
317 /// position `symbolPos`. The argument `opKind` specifies here what kind of
318 /// division or mod operation called this division. It helps in implementing the
319 /// commutative property of the floordiv and ceildiv operations. If the argument
320 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
321 /// operation, then the commutative property can be used otherwise, the floordiv
322 /// operation is not divisible. The same argument holds for ceildiv operation.
323 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
324                                 AffineExprKind opKind) {
325   // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
326   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
327           opKind == AffineExprKind::CeilDiv) &&
328          "unexpected opKind");
329   switch (expr.getKind()) {
330   case AffineExprKind::Constant:
331     return expr.cast<AffineConstantExpr>().getValue() == 0;
332   case AffineExprKind::DimId:
333     return false;
334   case AffineExprKind::SymbolId:
335     return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
336   // Checks divisibility by the given symbol for both operands.
337   case AffineExprKind::Add: {
338     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
339     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
340            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
341   }
342   // Checks divisibility by the given symbol for both operands. Consider the
343   // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
344   // this is a division by s1 and both the operands of modulo are divisible by
345   // s1 but it is not divisible by s1 always. The third argument is
346   // `AffineExprKind::Mod` for this reason.
347   case AffineExprKind::Mod: {
348     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
349     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
350                                AffineExprKind::Mod) &&
351            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
352                                AffineExprKind::Mod);
353   }
354   // Checks if any of the operand divisible by the given symbol.
355   case AffineExprKind::Mul: {
356     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
357     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
358            isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
359   }
360   // Floordiv and ceildiv are divisible by the given symbol when the first
361   // operand is divisible, and the affine expression kind of the argument expr
362   // is same as the argument `opKind`. This can be inferred from commutative
363   // property of floordiv and ceildiv operations and are as follow:
364   // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
365   // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
366   // It will fail if operations are not same. For example:
367   // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
368   case AffineExprKind::FloorDiv:
369   case AffineExprKind::CeilDiv: {
370     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
371     if (opKind != expr.getKind())
372       return false;
373     return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
374   }
375   }
376   llvm_unreachable("Unknown AffineExpr");
377 }
378 
379 /// Divides the given expression by the given symbol at position `symbolPos`. It
380 /// considers the divisibility condition is checked before calling itself. A
381 /// null expression is returned whenever the divisibility condition fails.
382 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
383                                  AffineExprKind opKind) {
384   // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
385   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
386           opKind == AffineExprKind::CeilDiv) &&
387          "unexpected opKind");
388   switch (expr.getKind()) {
389   case AffineExprKind::Constant:
390     if (expr.cast<AffineConstantExpr>().getValue() != 0)
391       return nullptr;
392     return getAffineConstantExpr(0, expr.getContext());
393   case AffineExprKind::DimId:
394     return nullptr;
395   case AffineExprKind::SymbolId:
396     return getAffineConstantExpr(1, expr.getContext());
397   // Dividing both operands by the given symbol.
398   case AffineExprKind::Add: {
399     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
400     return getAffineBinaryOpExpr(
401         expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
402         symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
403   }
404   // Dividing both operands by the given symbol.
405   case AffineExprKind::Mod: {
406     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
407     return getAffineBinaryOpExpr(
408         expr.getKind(),
409         symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
410         symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
411   }
412   // Dividing any of the operand by the given symbol.
413   case AffineExprKind::Mul: {
414     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
415     if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
416       return binaryExpr.getLHS() *
417              symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
418     return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
419            binaryExpr.getRHS();
420   }
421   // Dividing first operand only by the given symbol.
422   case AffineExprKind::FloorDiv:
423   case AffineExprKind::CeilDiv: {
424     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
425     return getAffineBinaryOpExpr(
426         expr.getKind(),
427         symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
428         binaryExpr.getRHS());
429   }
430   }
431   llvm_unreachable("Unknown AffineExpr");
432 }
433 
434 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
435 /// operations when the second operand simplifies to a symbol and the first
436 /// operand is divisible by that symbol. It can be applied to any semi-affine
437 /// expression. Returned expression can either be a semi-affine or pure affine
438 /// expression.
439 static AffineExpr simplifySemiAffine(AffineExpr expr) {
440   switch (expr.getKind()) {
441   case AffineExprKind::Constant:
442   case AffineExprKind::DimId:
443   case AffineExprKind::SymbolId:
444     return expr;
445   case AffineExprKind::Add:
446   case AffineExprKind::Mul: {
447     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
448     return getAffineBinaryOpExpr(expr.getKind(),
449                                  simplifySemiAffine(binaryExpr.getLHS()),
450                                  simplifySemiAffine(binaryExpr.getRHS()));
451   }
452   // Check if the simplification of the second operand is a symbol, and the
453   // first operand is divisible by it. If the operation is a modulo, a constant
454   // zero expression is returned. In the case of floordiv and ceildiv, the
455   // symbol from the simplification of the second operand divides the first
456   // operand. Otherwise, simplification is not possible.
457   case AffineExprKind::FloorDiv:
458   case AffineExprKind::CeilDiv:
459   case AffineExprKind::Mod: {
460     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
461     AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
462     AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
463     AffineSymbolExpr symbolExpr =
464         simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>();
465     if (!symbolExpr)
466       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
467     unsigned symbolPos = symbolExpr.getPosition();
468     if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
469       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
470     if (expr.getKind() == AffineExprKind::Mod)
471       return getAffineConstantExpr(0, expr.getContext());
472     return symbolicDivide(sLHS, symbolPos, expr.getKind());
473   }
474   }
475   llvm_unreachable("Unknown AffineExpr");
476 }
477 
478 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
479                                        MLIRContext *context) {
480   auto assignCtx = [context](AffineDimExprStorage *storage) {
481     storage->context = context;
482   };
483 
484   StorageUniquer &uniquer = context->getAffineUniquer();
485   return uniquer.get<AffineDimExprStorage>(
486       assignCtx, static_cast<unsigned>(kind), position);
487 }
488 
489 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
490   return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
491 }
492 
493 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
494     : AffineExpr(ptr) {}
495 unsigned AffineSymbolExpr::getPosition() const {
496   return static_cast<ImplType *>(expr)->position;
497 }
498 
499 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
500   return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
501   ;
502 }
503 
504 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
505     : AffineExpr(ptr) {}
506 int64_t AffineConstantExpr::getValue() const {
507   return static_cast<ImplType *>(expr)->constant;
508 }
509 
510 bool AffineExpr::operator==(int64_t v) const {
511   return *this == getAffineConstantExpr(v, getContext());
512 }
513 
514 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
515   auto assignCtx = [context](AffineConstantExprStorage *storage) {
516     storage->context = context;
517   };
518 
519   StorageUniquer &uniquer = context->getAffineUniquer();
520   return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
521 }
522 
523 /// Simplify add expression. Return nullptr if it can't be simplified.
524 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
525   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
526   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
527   // Fold if both LHS, RHS are a constant.
528   if (lhsConst && rhsConst)
529     return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
530                                  lhs.getContext());
531 
532   // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
533   // If only one of them is a symbolic expressions, make it the RHS.
534   if (lhs.isa<AffineConstantExpr>() ||
535       (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
536     return rhs + lhs;
537   }
538 
539   // At this point, if there was a constant, it would be on the right.
540 
541   // Addition with a zero is a noop, return the other input.
542   if (rhsConst) {
543     if (rhsConst.getValue() == 0)
544       return lhs;
545   }
546   // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
547   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
548   if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
549     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
550       return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
551   }
552 
553   // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
554   // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
555   // respective multiplicands.
556   Optional<int64_t> rLhsConst, rRhsConst;
557   AffineExpr firstExpr, secondExpr;
558   AffineConstantExpr rLhsConstExpr;
559   auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
560   if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
561       (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
562     rLhsConst = rLhsConstExpr.getValue();
563     firstExpr = lBinOpExpr.getLHS();
564   } else {
565     rLhsConst = 1;
566     firstExpr = lhs;
567   }
568 
569   auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
570   AffineConstantExpr rRhsConstExpr;
571   if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
572       (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
573     rRhsConst = rRhsConstExpr.getValue();
574     secondExpr = rBinOpExpr.getLHS();
575   } else {
576     rRhsConst = 1;
577     secondExpr = rhs;
578   }
579 
580   if (rLhsConst && rRhsConst && firstExpr == secondExpr)
581     return getAffineBinaryOpExpr(
582         AffineExprKind::Mul, firstExpr,
583         getAffineConstantExpr(rLhsConst.getValue() + rRhsConst.getValue(),
584                               lhs.getContext()));
585 
586   // When doing successive additions, bring constant to the right: turn (d0 + 2)
587   // + d1 into (d0 + d1) + 2.
588   if (lBin && lBin.getKind() == AffineExprKind::Add) {
589     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
590       return lBin.getLHS() + rhs + lrhs;
591     }
592   }
593 
594   // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
595   // q may be a constant or symbolic expression. This leads to a much more
596   // efficient form when 'c' is a power of two, and in general a more compact
597   // and readable form.
598 
599   // Process '(expr floordiv c) * (-c)'.
600   if (!rBinOpExpr)
601     return nullptr;
602 
603   auto lrhs = rBinOpExpr.getLHS();
604   auto rrhs = rBinOpExpr.getRHS();
605 
606   AffineExpr llrhs, rlrhs;
607 
608   // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
609   // symbolic expression.
610   auto lrhsBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
611   // Check rrhsConstOpExpr = -1.
612   auto rrhsConstOpExpr = rrhs.dyn_cast<AffineConstantExpr>();
613   if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
614       lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
615     // Check llrhs = expr floordiv q.
616     llrhs = lrhsBinOpExpr.getLHS();
617     // Check rlrhs = q.
618     rlrhs = lrhsBinOpExpr.getRHS();
619     auto llrhsBinOpExpr = llrhs.dyn_cast<AffineBinaryOpExpr>();
620     if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
621       return nullptr;
622     if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
623       return lhs % rlrhs;
624   }
625 
626   // Process lrhs, which is 'expr floordiv c'.
627   AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
628   if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
629     return nullptr;
630 
631   llrhs = lrBinOpExpr.getLHS();
632   rlrhs = lrBinOpExpr.getRHS();
633 
634   if (lhs == llrhs && rlrhs == -rrhs) {
635     return lhs % rlrhs;
636   }
637   return nullptr;
638 }
639 
640 AffineExpr AffineExpr::operator+(int64_t v) const {
641   return *this + getAffineConstantExpr(v, getContext());
642 }
643 AffineExpr AffineExpr::operator+(AffineExpr other) const {
644   if (auto simplified = simplifyAdd(*this, other))
645     return simplified;
646 
647   StorageUniquer &uniquer = getContext()->getAffineUniquer();
648   return uniquer.get<AffineBinaryOpExprStorage>(
649       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
650 }
651 
652 /// Simplify a multiply expression. Return nullptr if it can't be simplified.
653 static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
654   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
655   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
656 
657   if (lhsConst && rhsConst)
658     return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
659                                  lhs.getContext());
660 
661   assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
662 
663   // Canonicalize the mul expression so that the constant/symbolic term is the
664   // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
665   // constant. (Note that a constant is trivially symbolic).
666   if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
667     // At least one of them has to be symbolic.
668     return rhs * lhs;
669   }
670 
671   // At this point, if there was a constant, it would be on the right.
672 
673   // Multiplication with a one is a noop, return the other input.
674   if (rhsConst) {
675     if (rhsConst.getValue() == 1)
676       return lhs;
677     // Multiplication with zero.
678     if (rhsConst.getValue() == 0)
679       return rhsConst;
680   }
681 
682   // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
683   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
684   if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
685     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
686       return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
687   }
688 
689   // When doing successive multiplication, bring constant to the right: turn (d0
690   // * 2) * d1 into (d0 * d1) * 2.
691   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
692     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
693       return (lBin.getLHS() * rhs) * lrhs;
694     }
695   }
696 
697   return nullptr;
698 }
699 
700 AffineExpr AffineExpr::operator*(int64_t v) const {
701   return *this * getAffineConstantExpr(v, getContext());
702 }
703 AffineExpr AffineExpr::operator*(AffineExpr other) const {
704   if (auto simplified = simplifyMul(*this, other))
705     return simplified;
706 
707   StorageUniquer &uniquer = getContext()->getAffineUniquer();
708   return uniquer.get<AffineBinaryOpExprStorage>(
709       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
710 }
711 
712 // Unary minus, delegate to operator*.
713 AffineExpr AffineExpr::operator-() const {
714   return *this * getAffineConstantExpr(-1, getContext());
715 }
716 
717 // Delegate to operator+.
718 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
719 AffineExpr AffineExpr::operator-(AffineExpr other) const {
720   return *this + (-other);
721 }
722 
723 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
724   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
725   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
726 
727   // mlir floordiv by zero or negative numbers is undefined and preserved as is.
728   if (!rhsConst || rhsConst.getValue() < 1)
729     return nullptr;
730 
731   if (lhsConst)
732     return getAffineConstantExpr(
733         floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
734 
735   // Fold floordiv of a multiply with a constant that is a multiple of the
736   // divisor. Eg: (i * 128) floordiv 64 = i * 2.
737   if (rhsConst == 1)
738     return lhs;
739 
740   // Simplify (expr * const) floordiv divConst when expr is known to be a
741   // multiple of divConst.
742   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
743   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
744     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
745       // rhsConst is known to be a positive constant.
746       if (lrhs.getValue() % rhsConst.getValue() == 0)
747         return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
748     }
749   }
750 
751   // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
752   // known to be a multiple of divConst.
753   if (lBin && lBin.getKind() == AffineExprKind::Add) {
754     int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
755     int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
756     // rhsConst is known to be a positive constant.
757     if (llhsDiv % rhsConst.getValue() == 0 ||
758         lrhsDiv % rhsConst.getValue() == 0)
759       return lBin.getLHS().floorDiv(rhsConst.getValue()) +
760              lBin.getRHS().floorDiv(rhsConst.getValue());
761   }
762 
763   return nullptr;
764 }
765 
766 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
767   return floorDiv(getAffineConstantExpr(v, getContext()));
768 }
769 AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
770   if (auto simplified = simplifyFloorDiv(*this, other))
771     return simplified;
772 
773   StorageUniquer &uniquer = getContext()->getAffineUniquer();
774   return uniquer.get<AffineBinaryOpExprStorage>(
775       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
776       other);
777 }
778 
779 static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
780   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
781   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
782 
783   if (!rhsConst || rhsConst.getValue() < 1)
784     return nullptr;
785 
786   if (lhsConst)
787     return getAffineConstantExpr(
788         ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
789 
790   // Fold ceildiv of a multiply with a constant that is a multiple of the
791   // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
792   if (rhsConst.getValue() == 1)
793     return lhs;
794 
795   // Simplify (expr * const) ceildiv divConst when const is known to be a
796   // multiple of divConst.
797   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
798   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
799     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
800       // rhsConst is known to be a positive constant.
801       if (lrhs.getValue() % rhsConst.getValue() == 0)
802         return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
803     }
804   }
805 
806   return nullptr;
807 }
808 
809 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
810   return ceilDiv(getAffineConstantExpr(v, getContext()));
811 }
812 AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
813   if (auto simplified = simplifyCeilDiv(*this, other))
814     return simplified;
815 
816   StorageUniquer &uniquer = getContext()->getAffineUniquer();
817   return uniquer.get<AffineBinaryOpExprStorage>(
818       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
819       other);
820 }
821 
822 static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
823   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
824   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
825 
826   // mod w.r.t zero or negative numbers is undefined and preserved as is.
827   if (!rhsConst || rhsConst.getValue() < 1)
828     return nullptr;
829 
830   if (lhsConst)
831     return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
832                                  lhs.getContext());
833 
834   // Fold modulo of an expression that is known to be a multiple of a constant
835   // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
836   // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
837   if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
838     return getAffineConstantExpr(0, lhs.getContext());
839 
840   // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
841   // known to be a multiple of divConst.
842   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
843   if (lBin && lBin.getKind() == AffineExprKind::Add) {
844     int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
845     int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
846     // rhsConst is known to be a positive constant.
847     if (llhsDiv % rhsConst.getValue() == 0)
848       return lBin.getRHS() % rhsConst.getValue();
849     if (lrhsDiv % rhsConst.getValue() == 0)
850       return lBin.getLHS() % rhsConst.getValue();
851   }
852 
853   // Simplify (e % a) % b to e % b when b evenly divides a
854   if (lBin && lBin.getKind() == AffineExprKind::Mod) {
855     auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
856     if (intermediate && intermediate.getValue() >= 1 &&
857         mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
858       return lBin.getLHS() % rhsConst.getValue();
859     }
860   }
861 
862   return nullptr;
863 }
864 
865 AffineExpr AffineExpr::operator%(uint64_t v) const {
866   return *this % getAffineConstantExpr(v, getContext());
867 }
868 AffineExpr AffineExpr::operator%(AffineExpr other) const {
869   if (auto simplified = simplifyMod(*this, other))
870     return simplified;
871 
872   StorageUniquer &uniquer = getContext()->getAffineUniquer();
873   return uniquer.get<AffineBinaryOpExprStorage>(
874       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
875 }
876 
877 AffineExpr AffineExpr::compose(AffineMap map) const {
878   SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
879                                              map.getResults().end());
880   return replaceDimsAndSymbols(dimReplacements, {});
881 }
/// Stream insertion for AffineExpr: prints `expr` to `os` via
/// AffineExpr::print and returns `os` so that insertions can be chained.
raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
  expr.print(os);
  return os;
}
886 
887 /// Constructs an affine expression from a flat ArrayRef. If there are local
888 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
889 /// products expression, `localExprs` is expected to have the AffineExpr
890 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
891 /// in the format [dims, symbols, locals, constant term].
892 AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
893                                            unsigned numDims,
894                                            unsigned numSymbols,
895                                            ArrayRef<AffineExpr> localExprs,
896                                            MLIRContext *context) {
897   // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
898   assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
899          "unexpected number of local expressions");
900 
901   auto expr = getAffineConstantExpr(0, context);
902   // Dimensions and symbols.
903   for (unsigned j = 0; j < numDims + numSymbols; j++) {
904     if (flatExprs[j] == 0)
905       continue;
906     auto id = j < numDims ? getAffineDimExpr(j, context)
907                           : getAffineSymbolExpr(j - numDims, context);
908     expr = expr + id * flatExprs[j];
909   }
910 
911   // Local identifiers.
912   for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
913        j++) {
914     if (flatExprs[j] == 0)
915       continue;
916     auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
917     expr = expr + term;
918   }
919 
920   // Constant term.
921   int64_t constTerm = flatExprs[flatExprs.size() - 1];
922   if (constTerm != 0)
923     expr = expr + constTerm;
924   return expr;
925 }
926 
927 /// Constructs a semi-affine expression from a flat ArrayRef. If there are
928 /// local identifiers (neither dimensional nor symbolic) that appear in the sum
929 /// of products expression, `localExprs` is expected to have the AffineExprs for
930 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
931 /// the format [dims, symbols, locals, constant term]. The semi-affine
932 /// expression is constructed in the sorted order of dimension and symbol
933 /// position numbers. Note:  local expressions/ids are used for mod, div as well
934 /// as symbolic RHS terms for terms that are not pure affine.
935 static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
936                                                 unsigned numDims,
937                                                 unsigned numSymbols,
938                                                 ArrayRef<AffineExpr> localExprs,
939                                                 MLIRContext *context) {
940   assert(!flatExprs.empty() && "flatExprs cannot be empty");
941 
942   // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
943   assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
944          "unexpected number of local expressions");
945 
946   AffineExpr expr = getAffineConstantExpr(0, context);
947 
948   // We design indices as a pair which help us present the semi-affine map as
949   // sum of product where terms are sorted based on dimension or symbol
950   // position: <keyA, keyB> for expressions of the form dimension * symbol,
951   // where keyA is the position number of the dimension and keyB is the
952   // position number of the symbol. For dimensional expressions we set the index
953   // as (position number of the dimension, -1), as we want dimensional
954   // expressions to appear before symbolic and product of dimensional and
955   // symbolic expressions having the dimension with the same position number.
956   // For symbolic expression set the index as (position number of the symbol,
957   // maximum of last dimension and symbol position) number. For example, we want
958   // the expression we are constructing to look something like: d0 + d0 * s0 +
959   // s0 + d1*s1 + s1.
960 
961   // Stores the affine expression corresponding to a given index.
962   DenseMap<std::pair<unsigned, signed>, AffineExpr> indexToExprMap;
963   // Stores the constant coefficient value corresponding to a given
964   // dimension, symbol or a non-pure affine expression stored in `localExprs`.
965   DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
966   // Stores the indices as defined above, and later sorted to produce
967   // the semi-affine expression in the desired form.
968   SmallVector<std::pair<unsigned, signed>, 8> indices;
969 
970   // Example: expression = d0 + d0 * s0 + 2 * s0.
971   // indices = [{0,-1}, {0, 0}, {0, 1}]
972   // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
973   // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
974 
975   // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
976   auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
977                       AffineExpr expr) {
978     assert(std::find(indices.begin(), indices.end(), index) == indices.end() &&
979            "Key is already present in indices vector and overwriting will "
980            "happen in `indexToExprMap` and `coefficients`!");
981 
982     indices.push_back(index);
983     coefficients.insert({index, coefficient});
984     indexToExprMap.insert({index, expr});
985   };
986 
987   // Design indices for dimensional or symbolic terms, and store the indices,
988   // constant coefficient corresponding to the indices in `coefficients` map,
989   // and affine expression corresponding to indices in `indexToExprMap` map.
990 
991   for (unsigned j = 0; j < numDims; ++j) {
992     if (flatExprs[j] == 0)
993       continue;
994     // For dimensional expressions we set the index as <position number of the
995     // dimension, 0>, as we want dimensional expressions to appear before
996     // symbolic ones and products of dimensional and symbolic expressions
997     // having the dimension with the same position number.
998     std::pair<unsigned, signed> indexEntry(j, -1);
999     addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
1000   }
1001   for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1002     if (flatExprs[j] == 0)
1003       continue;
1004     // For symbolic expression set the index as <position number
1005     // of the symbol, max(dimCount, symCount)> number,
1006     // as we want symbolic expressions with the same positional number to
1007     // appear after dimensional expressions having the same positional number.
1008     std::pair<unsigned, signed> indexEntry(j - numDims,
1009                                            std::max(numDims, numSymbols));
1010     addEntry(indexEntry, flatExprs[j],
1011              getAffineSymbolExpr(j - numDims, context));
1012   }
1013 
1014   // Denotes semi-affine product, modulo or division terms, which has been added
1015   // to the `indexToExpr` map.
1016   SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1017                                   false);
1018   unsigned lhsPos, rhsPos;
1019   // Construct indices for product terms involving dimension, symbol or constant
1020   // as lhs/rhs, and store the indices, constant coefficient corresponding to
1021   // the indices in `coefficients` map, and affine expression corresponding to
1022   // in indices in `indexToExprMap` map.
1023   for (const auto &it : llvm::enumerate(localExprs)) {
1024     AffineExpr expr = it.value();
1025     if (flatExprs[numDims + numSymbols + it.index()] == 0)
1026       continue;
1027     AffineExpr lhs = expr.cast<AffineBinaryOpExpr>().getLHS();
1028     AffineExpr rhs = expr.cast<AffineBinaryOpExpr>().getRHS();
1029     if (!((lhs.isa<AffineDimExpr>() || lhs.isa<AffineSymbolExpr>()) &&
1030           (rhs.isa<AffineDimExpr>() || rhs.isa<AffineSymbolExpr>() ||
1031            rhs.isa<AffineConstantExpr>()))) {
1032       continue;
1033     }
1034     if (rhs.isa<AffineConstantExpr>()) {
1035       // For product/modulo/division expressions, when rhs of modulo/division
1036       // expression is constant, we put 0 in place of keyB, because we want
1037       // them to appear earlier in the semi-affine expression we are
1038       // constructing. When rhs is constant, we place 0 in place of keyB.
1039       if (lhs.isa<AffineDimExpr>()) {
1040         lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1041         std::pair<unsigned, signed> indexEntry(lhsPos, -1);
1042         addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1043                  expr);
1044       } else {
1045         lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1046         std::pair<unsigned, signed> indexEntry(lhsPos,
1047                                                std::max(numDims, numSymbols));
1048         addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1049                  expr);
1050       }
1051     } else if (lhs.isa<AffineDimExpr>()) {
1052       // For product/modulo/division expressions having lhs as dimension and rhs
1053       // as symbol, we order the terms in the semi-affine expression based on
1054       // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1055       // where keyA is the position number of the dimension and keyB is the
1056       // position number of the symbol.
1057       lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1058       rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1059       std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1060       addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1061     } else {
1062       // For product/modulo/division expressions having both lhs and rhs as
1063       // symbol, we design indices as a pair: <keyA, keyB> for expressions
1064       // of the form dimension * symbol, where keyA is the position number of
1065       // the dimension and keyB is the position number of the symbol.
1066       lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1067       rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1068       std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1069       addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1070     }
1071     addedToMap[it.index()] = true;
1072   }
1073 
1074   // Constructing the simplified semi-affine sum of product/division/mod
1075   // expression from the flattened form in the desired sorted order of indices
1076   // of the various individual product/division/mod expressions.
1077   std::sort(indices.begin(), indices.end());
1078   for (const std::pair<unsigned, unsigned> index : indices) {
1079     assert(indexToExprMap.lookup(index) &&
1080            "cannot find key in `indexToExprMap` map");
1081     expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1082   }
1083 
1084   // Local identifiers.
1085   for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1086        j++) {
1087     // If the coefficient of the local expression is 0, continue as we need not
1088     // add it in out final expression.
1089     if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1090       continue;
1091     auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1092     expr = expr + term;
1093   }
1094 
1095   // Constant term.
1096   int64_t constTerm = flatExprs.back();
1097   if (constTerm != 0)
1098     expr = expr + constTerm;
1099   return expr;
1100 }
1101 
1102 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
1103                                                      unsigned numSymbols)
1104     : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1105   operandExprStack.reserve(8);
1106 }
1107 
1108 // In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1109 //
1110 // In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1111 // introduce a local variable p (= expr * symbolic_expr), and the affine
1112 // expression expr * symbolic_expr is added to `localExprs`.
1113 void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
1114   assert(operandExprStack.size() >= 2);
1115   SmallVector<int64_t, 8> rhs = operandExprStack.back();
1116   operandExprStack.pop_back();
1117   SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1118 
1119   // Flatten semi-affine multiplication expressions by introducing a local
1120   // variable in place of the product; the affine expression
1121   // corresponding to the quantifier is added to `localExprs`.
1122   if (!expr.getRHS().isa<AffineConstantExpr>()) {
1123     MLIRContext *context = expr.getContext();
1124     AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
1125                                              localExprs, context);
1126     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1127                                              localExprs, context);
1128     addLocalVariableSemiAffine(a * b, lhs, lhs.size());
1129     return;
1130   }
1131 
1132   // Get the RHS constant.
1133   auto rhsConst = rhs[getConstantIndex()];
1134   for (unsigned i = 0, e = lhs.size(); i < e; i++) {
1135     lhs[i] *= rhsConst;
1136   }
1137 }
1138 
1139 void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
1140   assert(operandExprStack.size() >= 2);
1141   const auto &rhs = operandExprStack.back();
1142   auto &lhs = operandExprStack[operandExprStack.size() - 2];
1143   assert(lhs.size() == rhs.size());
1144   // Update the LHS in place.
1145   for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1146     lhs[i] += rhs[i];
1147   }
1148   // Pop off the RHS.
1149   operandExprStack.pop_back();
1150 }
1151 
1152 //
1153 // t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
1154 //
1155 // A mod expression "expr mod c" is thus flattened by introducing a new local
1156 // variable q (= expr floordiv c), such that expr mod c is replaced with
1157 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1158 //
1159 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1160 // introduce a local variable m (= expr mod symbolic_expr), and the affine
1161 // expression expr mod symbolic_expr is added to `localExprs`.
1162 void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
1163   assert(operandExprStack.size() >= 2);
1164 
1165   SmallVector<int64_t, 8> rhs = operandExprStack.back();
1166   operandExprStack.pop_back();
1167   SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1168   MLIRContext *context = expr.getContext();
1169 
1170   // Flatten semi affine modulo expressions by introducing a local
1171   // variable in place of the modulo value, and the affine expression
1172   // corresponding to the quantifier is added to `localExprs`.
1173   if (!expr.getRHS().isa<AffineConstantExpr>()) {
1174     AffineExpr dividendExpr = getAffineExprFromFlatForm(
1175         lhs, numDims, numSymbols, localExprs, context);
1176     AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1177                                                        localExprs, context);
1178     AffineExpr modExpr = dividendExpr % divisorExpr;
1179     addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
1180     return;
1181   }
1182 
1183   int64_t rhsConst = rhs[getConstantIndex()];
1184   // TODO: handle modulo by zero case when this issue is fixed
1185   // at the other places in the IR.
1186   assert(rhsConst > 0 && "RHS constant has to be positive");
1187 
1188   // Check if the LHS expression is a multiple of modulo factor.
1189   unsigned i, e;
1190   for (i = 0, e = lhs.size(); i < e; i++)
1191     if (lhs[i] % rhsConst != 0)
1192       break;
1193   // If yes, modulo expression here simplifies to zero.
1194   if (i == lhs.size()) {
1195     std::fill(lhs.begin(), lhs.end(), 0);
1196     return;
1197   }
1198 
1199   // Add a local variable for the quotient, i.e., expr % c is replaced by
1200   // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1201   // the GCD of expr and c.
1202   SmallVector<int64_t, 8> floorDividend(lhs);
1203   uint64_t gcd = rhsConst;
1204   for (unsigned i = 0, e = lhs.size(); i < e; i++)
1205     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1206   // Simplify the numerator and the denominator.
1207   if (gcd != 1) {
1208     for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
1209       floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
1210   }
1211   int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1212 
1213   // Construct the AffineExpr form of the floordiv to store in localExprs.
1214 
1215   AffineExpr dividendExpr = getAffineExprFromFlatForm(
1216       floorDividend, numDims, numSymbols, localExprs, context);
1217   AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
1218   AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
1219   int loc;
1220   if ((loc = findLocalId(floorDivExpr)) == -1) {
1221     addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
1222     // Set result at top of stack to "lhs - rhsConst * q".
1223     lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1224   } else {
1225     // Reuse the existing local id.
1226     lhs[getLocalVarStartIndex() + loc] = -rhsConst;
1227   }
1228 }
1229 
1230 void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
1231   visitDivExpr(expr, /*isCeil=*/true);
1232 }
1233 void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
1234   visitDivExpr(expr, /*isCeil=*/false);
1235 }
1236 
1237 void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
1238   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1239   auto &eq = operandExprStack.back();
1240   assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1241   eq[getDimStartIndex() + expr.getPosition()] = 1;
1242 }
1243 
1244 void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
1245   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1246   auto &eq = operandExprStack.back();
1247   assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1248   eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1249 }
1250 
1251 void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
1252   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1253   auto &eq = operandExprStack.back();
1254   eq[getConstantIndex()] = expr.getValue();
1255 }
1256 
1257 void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1258     AffineExpr expr, SmallVectorImpl<int64_t> &result,
1259     unsigned long resultSize) {
1260   assert(result.size() == resultSize &&
1261          "`result` vector passed is not of correct size");
1262   int loc;
1263   if ((loc = findLocalId(expr)) == -1)
1264     addLocalIdSemiAffine(expr);
1265   std::fill(result.begin(), result.end(), 0);
1266   if (loc == -1)
1267     result[getLocalVarStartIndex() + numLocals - 1] = 1;
1268   else
1269     result[getLocalVarStartIndex() + loc] = 1;
1270 }
1271 
1272 // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
1273 // A floordiv is thus flattened by introducing a new local variable q, and
1274 // replacing that expression with 'q' while adding the constraints
1275 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
1276 // FlatAffineConstraints::addLocalFloorDiv).
1277 //
1278 // A ceildiv is similarly flattened:
1279 // t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
1280 //
1281 // In case of semi affine division expressions, t = expr floordiv symbolic_expr
1282 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1283 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1284 // `localExprs`.
1285 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1286                                              bool isCeil) {
1287   assert(operandExprStack.size() >= 2);
1288 
1289   MLIRContext *context = expr.getContext();
1290   SmallVector<int64_t, 8> rhs = operandExprStack.back();
1291   operandExprStack.pop_back();
1292   SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1293 
1294   // Flatten semi affine division expressions by introducing a local
1295   // variable in place of the quotient, and the affine expression corresponding
1296   // to the quantifier is added to `localExprs`.
1297   if (!expr.getRHS().isa<AffineConstantExpr>()) {
1298     AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
1299                                              localExprs, context);
1300     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1301                                              localExprs, context);
1302     AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1303     addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
1304     return;
1305   }
1306 
1307   // This is a pure affine expr; the RHS is a positive constant.
1308   int64_t rhsConst = rhs[getConstantIndex()];
1309   // TODO: handle division by zero at the same time the issue is
1310   // fixed at other places.
1311   assert(rhsConst > 0 && "RHS constant has to be positive");
1312 
1313   // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1314   // common divisors of the numerator and denominator.
1315   uint64_t gcd = std::abs(rhsConst);
1316   for (unsigned i = 0, e = lhs.size(); i < e; i++)
1317     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1318   // Simplify the numerator and the denominator.
1319   if (gcd != 1) {
1320     for (unsigned i = 0, e = lhs.size(); i < e; i++)
1321       lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
1322   }
1323   int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1324   // If the divisor becomes 1, the updated LHS is the result. (The
1325   // divisor can't be negative since rhsConst is positive).
1326   if (divisor == 1)
1327     return;
1328 
1329   // If the divisor cannot be simplified to one, we will have to retain
1330   // the ceil/floor expr (simplified up until here). Add an existential
1331   // quantifier to express its result, i.e., expr1 div expr2 is replaced
1332   // by a new identifier, q.
1333   AffineExpr a =
1334       getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context);
1335   AffineExpr b = getAffineConstantExpr(divisor, context);
1336 
1337   int loc;
1338   AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1339   if ((loc = findLocalId(divExpr)) == -1) {
1340     if (!isCeil) {
1341       SmallVector<int64_t, 8> dividend(lhs);
1342       addLocalFloorDivId(dividend, divisor, divExpr);
1343     } else {
1344       // lhs ceildiv c <=>  (lhs + c - 1) floordiv c
1345       SmallVector<int64_t, 8> dividend(lhs);
1346       dividend.back() += divisor - 1;
1347       addLocalFloorDivId(dividend, divisor, divExpr);
1348     }
1349   }
1350   // Set the expression on stack to the local var introduced to capture the
1351   // result of the division (floor or ceil).
1352   std::fill(lhs.begin(), lhs.end(), 0);
1353   if (loc == -1)
1354     lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1355   else
1356     lhs[getLocalVarStartIndex() + loc] = 1;
1357 }
1358 
1359 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1360 // The local identifier added is always a floordiv of a pure add/mul affine
1361 // function of other identifiers, coefficients of which are specified in
1362 // dividend and with respect to a positive constant divisor. localExpr is the
1363 // simplified tree expression (AffineExpr) corresponding to the quantifier.
1364 void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
1365                                                    int64_t divisor,
1366                                                    AffineExpr localExpr) {
1367   assert(divisor > 0 && "positive constant divisor expected");
1368   for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1369     subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1370   localExprs.push_back(localExpr);
1371   numLocals++;
1372   // dividend and divisor are not used here; an override of this method uses it.
1373 }
1374 
1375 void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr) {
1376   for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1377     subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1378   localExprs.push_back(localExpr);
1379   ++numLocals;
1380 }
1381 
1382 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1383   SmallVectorImpl<AffineExpr>::iterator it;
1384   if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1385     return -1;
1386   return it - localExprs.begin();
1387 }
1388 
1389 /// Simplify the affine expression by flattening it and reconstructing it.
1390 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
1391                                     unsigned numSymbols) {
1392   // Simplify semi-affine expressions separately.
1393   if (!expr.isPureAffine())
1394     expr = simplifySemiAffine(expr);
1395 
1396   SimpleAffineExprFlattener flattener(numDims, numSymbols);
1397   flattener.walkPostOrder(expr);
1398   ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1399   if (!expr.isPureAffine() &&
1400       expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1401                                         flattener.localExprs,
1402                                         expr.getContext()))
1403     return expr;
1404   AffineExpr simplifiedExpr =
1405       expr.isPureAffine()
1406           ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1407                                       flattener.localExprs, expr.getContext())
1408           : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1409                                           flattener.localExprs,
1410                                           expr.getContext());
1411 
1412   flattener.operandExprStack.pop_back();
1413   assert(flattener.operandExprStack.empty());
1414   return simplifiedExpr;
1415 }
1416