1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include <utility>
10 
11 #include "AffineExprDetail.h"
12 #include "mlir/IR/AffineExpr.h"
13 #include "mlir/IR/AffineExprVisitor.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/Support/MathExtras.h"
17 #include "mlir/Support/TypeID.h"
18 #include "llvm/ADT/STLExtras.h"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
getContext() const23 MLIRContext *AffineExpr::getContext() const { return expr->context; }
24 
getKind() const25 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
26 
27 /// Walk all of the AffineExprs in this subgraph in postorder.
walk(std::function<void (AffineExpr)> callback) const28 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
29   struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
30     std::function<void(AffineExpr)> callback;
31 
32     AffineExprWalker(std::function<void(AffineExpr)> callback)
33         : callback(std::move(callback)) {}
34 
35     void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
36     void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
37     void visitDimExpr(AffineDimExpr expr) { callback(expr); }
38     void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
39   };
40 
41   AffineExprWalker(std::move(callback)).walkPostOrder(*this);
42 }
43 
44 // Dispatch affine expression construction based on kind.
getAffineBinaryOpExpr(AffineExprKind kind,AffineExpr lhs,AffineExpr rhs)45 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
46                                        AffineExpr rhs) {
47   if (kind == AffineExprKind::Add)
48     return lhs + rhs;
49   if (kind == AffineExprKind::Mul)
50     return lhs * rhs;
51   if (kind == AffineExprKind::FloorDiv)
52     return lhs.floorDiv(rhs);
53   if (kind == AffineExprKind::CeilDiv)
54     return lhs.ceilDiv(rhs);
55   if (kind == AffineExprKind::Mod)
56     return lhs % rhs;
57 
58   llvm_unreachable("unknown binary operation on affine expressions");
59 }
60 
61 /// This method substitutes any uses of dimensions and symbols (e.g.
62 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
63 AffineExpr
replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,ArrayRef<AffineExpr> symReplacements) const64 AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
65                                   ArrayRef<AffineExpr> symReplacements) const {
66   switch (getKind()) {
67   case AffineExprKind::Constant:
68     return *this;
69   case AffineExprKind::DimId: {
70     unsigned dimId = cast<AffineDimExpr>().getPosition();
71     if (dimId >= dimReplacements.size())
72       return *this;
73     return dimReplacements[dimId];
74   }
75   case AffineExprKind::SymbolId: {
76     unsigned symId = cast<AffineSymbolExpr>().getPosition();
77     if (symId >= symReplacements.size())
78       return *this;
79     return symReplacements[symId];
80   }
81   case AffineExprKind::Add:
82   case AffineExprKind::Mul:
83   case AffineExprKind::FloorDiv:
84   case AffineExprKind::CeilDiv:
85   case AffineExprKind::Mod:
86     auto binOp = cast<AffineBinaryOpExpr>();
87     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
88     auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
89     auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
90     if (newLHS == lhs && newRHS == rhs)
91       return *this;
92     return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
93   }
94   llvm_unreachable("Unknown AffineExpr");
95 }
96 
replaceDims(ArrayRef<AffineExpr> dimReplacements) const97 AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const {
98   return replaceDimsAndSymbols(dimReplacements, {});
99 }
100 
101 AffineExpr
replaceSymbols(ArrayRef<AffineExpr> symReplacements) const102 AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const {
103   return replaceDimsAndSymbols({}, symReplacements);
104 }
105 
106 /// Replace dims[offset ... numDims)
107 /// by dims[offset + shift ... shift + numDims).
shiftDims(unsigned numDims,unsigned shift,unsigned offset) const108 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
109                                  unsigned offset) const {
110   SmallVector<AffineExpr, 4> dims;
111   for (unsigned idx = 0; idx < offset; ++idx)
112     dims.push_back(getAffineDimExpr(idx, getContext()));
113   for (unsigned idx = offset; idx < numDims; ++idx)
114     dims.push_back(getAffineDimExpr(idx + shift, getContext()));
115   return replaceDimsAndSymbols(dims, {});
116 }
117 
118 /// Replace symbols[offset ... numSymbols)
119 /// by symbols[offset + shift ... shift + numSymbols).
shiftSymbols(unsigned numSymbols,unsigned shift,unsigned offset) const120 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
121                                     unsigned offset) const {
122   SmallVector<AffineExpr, 4> symbols;
123   for (unsigned idx = 0; idx < offset; ++idx)
124     symbols.push_back(getAffineSymbolExpr(idx, getContext()));
125   for (unsigned idx = offset; idx < numSymbols; ++idx)
126     symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
127   return replaceDimsAndSymbols({}, symbols);
128 }
129 
130 /// Sparse replace method. Return the modified expression tree.
131 AffineExpr
replace(const DenseMap<AffineExpr,AffineExpr> & map) const132 AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
133   auto it = map.find(*this);
134   if (it != map.end())
135     return it->second;
136   switch (getKind()) {
137   default:
138     return *this;
139   case AffineExprKind::Add:
140   case AffineExprKind::Mul:
141   case AffineExprKind::FloorDiv:
142   case AffineExprKind::CeilDiv:
143   case AffineExprKind::Mod:
144     auto binOp = cast<AffineBinaryOpExpr>();
145     auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
146     auto newLHS = lhs.replace(map);
147     auto newRHS = rhs.replace(map);
148     if (newLHS == lhs && newRHS == rhs)
149       return *this;
150     return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
151   }
152   llvm_unreachable("Unknown AffineExpr");
153 }
154 
155 /// Sparse replace method. Return the modified expression tree.
replace(AffineExpr expr,AffineExpr replacement) const156 AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
157   DenseMap<AffineExpr, AffineExpr> map;
158   map.insert(std::make_pair(expr, replacement));
159   return replace(map);
160 }
161 /// Returns true if this expression is made out of only symbols and
162 /// constants (no dimensional identifiers).
isSymbolicOrConstant() const163 bool AffineExpr::isSymbolicOrConstant() const {
164   switch (getKind()) {
165   case AffineExprKind::Constant:
166     return true;
167   case AffineExprKind::DimId:
168     return false;
169   case AffineExprKind::SymbolId:
170     return true;
171 
172   case AffineExprKind::Add:
173   case AffineExprKind::Mul:
174   case AffineExprKind::FloorDiv:
175   case AffineExprKind::CeilDiv:
176   case AffineExprKind::Mod: {
177     auto expr = this->cast<AffineBinaryOpExpr>();
178     return expr.getLHS().isSymbolicOrConstant() &&
179            expr.getRHS().isSymbolicOrConstant();
180   }
181   }
182   llvm_unreachable("Unknown AffineExpr");
183 }
184 
185 /// Returns true if this is a pure affine expression, i.e., multiplication,
186 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
isPureAffine() const187 bool AffineExpr::isPureAffine() const {
188   switch (getKind()) {
189   case AffineExprKind::SymbolId:
190   case AffineExprKind::DimId:
191   case AffineExprKind::Constant:
192     return true;
193   case AffineExprKind::Add: {
194     auto op = cast<AffineBinaryOpExpr>();
195     return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
196   }
197 
198   case AffineExprKind::Mul: {
199     // TODO: Canonicalize the constants in binary operators to the RHS when
200     // possible, allowing this to merge into the next case.
201     auto op = cast<AffineBinaryOpExpr>();
202     return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
203            (op.getLHS().template isa<AffineConstantExpr>() ||
204             op.getRHS().template isa<AffineConstantExpr>());
205   }
206   case AffineExprKind::FloorDiv:
207   case AffineExprKind::CeilDiv:
208   case AffineExprKind::Mod: {
209     auto op = cast<AffineBinaryOpExpr>();
210     return op.getLHS().isPureAffine() &&
211            op.getRHS().template isa<AffineConstantExpr>();
212   }
213   }
214   llvm_unreachable("Unknown AffineExpr");
215 }
216 
217 // Returns the greatest known integral divisor of this affine expression.
getLargestKnownDivisor() const218 int64_t AffineExpr::getLargestKnownDivisor() const {
219   AffineBinaryOpExpr binExpr(nullptr);
220   switch (getKind()) {
221   case AffineExprKind::CeilDiv:
222     LLVM_FALLTHROUGH;
223   case AffineExprKind::DimId:
224   case AffineExprKind::FloorDiv:
225   case AffineExprKind::SymbolId:
226     return 1;
227   case AffineExprKind::Constant:
228     return std::abs(this->cast<AffineConstantExpr>().getValue());
229   case AffineExprKind::Mul: {
230     binExpr = this->cast<AffineBinaryOpExpr>();
231     return binExpr.getLHS().getLargestKnownDivisor() *
232            binExpr.getRHS().getLargestKnownDivisor();
233   }
234   case AffineExprKind::Add:
235     LLVM_FALLTHROUGH;
236   case AffineExprKind::Mod: {
237     binExpr = cast<AffineBinaryOpExpr>();
238     return llvm::GreatestCommonDivisor64(
239         binExpr.getLHS().getLargestKnownDivisor(),
240         binExpr.getRHS().getLargestKnownDivisor());
241   }
242   }
243   llvm_unreachable("Unknown AffineExpr");
244 }
245 
isMultipleOf(int64_t factor) const246 bool AffineExpr::isMultipleOf(int64_t factor) const {
247   AffineBinaryOpExpr binExpr(nullptr);
248   uint64_t l, u;
249   switch (getKind()) {
250   case AffineExprKind::SymbolId:
251     LLVM_FALLTHROUGH;
252   case AffineExprKind::DimId:
253     return factor * factor == 1;
254   case AffineExprKind::Constant:
255     return cast<AffineConstantExpr>().getValue() % factor == 0;
256   case AffineExprKind::Mul: {
257     binExpr = cast<AffineBinaryOpExpr>();
258     // It's probably not worth optimizing this further (to not traverse the
259     // whole sub-tree under - it that would require a version of isMultipleOf
260     // that on a 'false' return also returns the largest known divisor).
261     return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
262            (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
263            (l * u) % factor == 0;
264   }
265   case AffineExprKind::Add:
266   case AffineExprKind::FloorDiv:
267   case AffineExprKind::CeilDiv:
268   case AffineExprKind::Mod: {
269     binExpr = cast<AffineBinaryOpExpr>();
270     return llvm::GreatestCommonDivisor64(
271                binExpr.getLHS().getLargestKnownDivisor(),
272                binExpr.getRHS().getLargestKnownDivisor()) %
273                factor ==
274            0;
275   }
276   }
277   llvm_unreachable("Unknown AffineExpr");
278 }
279 
isFunctionOfDim(unsigned position) const280 bool AffineExpr::isFunctionOfDim(unsigned position) const {
281   if (getKind() == AffineExprKind::DimId) {
282     return *this == mlir::getAffineDimExpr(position, getContext());
283   }
284   if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
285     return expr.getLHS().isFunctionOfDim(position) ||
286            expr.getRHS().isFunctionOfDim(position);
287   }
288   return false;
289 }
290 
isFunctionOfSymbol(unsigned position) const291 bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
292   if (getKind() == AffineExprKind::SymbolId) {
293     return *this == mlir::getAffineSymbolExpr(position, getContext());
294   }
295   if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
296     return expr.getLHS().isFunctionOfSymbol(position) ||
297            expr.getRHS().isFunctionOfSymbol(position);
298   }
299   return false;
300 }
301 
AffineBinaryOpExpr(AffineExpr::ImplType * ptr)302 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
303     : AffineExpr(ptr) {}
getLHS() const304 AffineExpr AffineBinaryOpExpr::getLHS() const {
305   return static_cast<ImplType *>(expr)->lhs;
306 }
getRHS() const307 AffineExpr AffineBinaryOpExpr::getRHS() const {
308   return static_cast<ImplType *>(expr)->rhs;
309 }
310 
AffineDimExpr(AffineExpr::ImplType * ptr)311 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
getPosition() const312 unsigned AffineDimExpr::getPosition() const {
313   return static_cast<ImplType *>(expr)->position;
314 }
315 
/// Returns true if `expr` is known to be divisible by the symbol at position
/// `symbolPos`. Recognized divisible forms are: the constant 0, the symbol
/// itself, and sums, products, mods and (matching) floordiv/ceildiv built from
/// them as described per case below. Dimensions are never divisible.
/// The argument `opKind` specifies which division or mod operation asked for
/// this check. It is needed to apply the commutative property of floordiv and
/// ceildiv: if `opKind` is floordiv and `expr` is also a floordiv, the two
/// divisions can be swapped, so only the inner dividend must be divisible.
/// If the kinds differ, the nested division is treated as not divisible. The
/// same argument holds for ceildiv.
static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
                                AffineExprKind opKind) {
  // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
          opKind == AffineExprKind::CeilDiv) &&
         "unexpected opKind");
  switch (expr.getKind()) {
  // Only the constant 0 is divisible by an arbitrary symbol.
  case AffineExprKind::Constant:
    return expr.cast<AffineConstantExpr>().getValue() == 0;
  case AffineExprKind::DimId:
    return false;
  // A symbol is divisible only by itself.
  case AffineExprKind::SymbolId:
    return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
  // Checks divisibility by the given symbol for both operands.
  case AffineExprKind::Add: {
    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
  }
  // Checks divisibility by the given symbol for both operands. Consider the
  // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`:
  // this is a division by s1, and both operands of the mod look divisible by
  // s1 under floordiv rules, yet the whole expression is not always divisible
  // by s1. The operands are therefore checked with `AffineExprKind::Mod` as
  // the third argument, which disables the floordiv/ceildiv commutation.
  case AffineExprKind::Mod: {
    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
                               AffineExprKind::Mod) &&
           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
                               AffineExprKind::Mod);
  }
  // Checks whether either operand is divisible by the given symbol.
  case AffineExprKind::Mul: {
    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
           isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
  }
  // Floordiv and ceildiv are divisible by the given symbol when the first
  // operand is divisible and the kind of `expr` equals `opKind`. This follows
  // from the commutative property of floordiv and ceildiv:
  // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
  // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv exp2
  // It fails if the operations are not the same. For example,
  // (exp1 ceildiv exp2) floordiv exp3 cannot be simplified.
  case AffineExprKind::FloorDiv:
  case AffineExprKind::CeilDiv: {
    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
    if (opKind != expr.getKind())
      return false;
    return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
  }
  }
  llvm_unreachable("Unknown AffineExpr");
}
378 
379 /// Divides the given expression by the given symbol at position `symbolPos`. It
380 /// considers the divisibility condition is checked before calling itself. A
381 /// null expression is returned whenever the divisibility condition fails.
symbolicDivide(AffineExpr expr,unsigned symbolPos,AffineExprKind opKind)382 static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
383                                  AffineExprKind opKind) {
384   // THe argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
385   assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
386           opKind == AffineExprKind::CeilDiv) &&
387          "unexpected opKind");
388   switch (expr.getKind()) {
389   case AffineExprKind::Constant:
390     if (expr.cast<AffineConstantExpr>().getValue() != 0)
391       return nullptr;
392     return getAffineConstantExpr(0, expr.getContext());
393   case AffineExprKind::DimId:
394     return nullptr;
395   case AffineExprKind::SymbolId:
396     return getAffineConstantExpr(1, expr.getContext());
397   // Dividing both operands by the given symbol.
398   case AffineExprKind::Add: {
399     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
400     return getAffineBinaryOpExpr(
401         expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
402         symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
403   }
404   // Dividing both operands by the given symbol.
405   case AffineExprKind::Mod: {
406     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
407     return getAffineBinaryOpExpr(
408         expr.getKind(),
409         symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
410         symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
411   }
412   // Dividing any of the operand by the given symbol.
413   case AffineExprKind::Mul: {
414     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
415     if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
416       return binaryExpr.getLHS() *
417              symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
418     return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
419            binaryExpr.getRHS();
420   }
421   // Dividing first operand only by the given symbol.
422   case AffineExprKind::FloorDiv:
423   case AffineExprKind::CeilDiv: {
424     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
425     return getAffineBinaryOpExpr(
426         expr.getKind(),
427         symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
428         binaryExpr.getRHS());
429   }
430   }
431   llvm_unreachable("Unknown AffineExpr");
432 }
433 
434 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
435 /// operations when the second operand simplifies to a symbol and the first
436 /// operand is divisible by that symbol. It can be applied to any semi-affine
437 /// expression. Returned expression can either be a semi-affine or pure affine
438 /// expression.
simplifySemiAffine(AffineExpr expr)439 static AffineExpr simplifySemiAffine(AffineExpr expr) {
440   switch (expr.getKind()) {
441   case AffineExprKind::Constant:
442   case AffineExprKind::DimId:
443   case AffineExprKind::SymbolId:
444     return expr;
445   case AffineExprKind::Add:
446   case AffineExprKind::Mul: {
447     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
448     return getAffineBinaryOpExpr(expr.getKind(),
449                                  simplifySemiAffine(binaryExpr.getLHS()),
450                                  simplifySemiAffine(binaryExpr.getRHS()));
451   }
452   // Check if the simplification of the second operand is a symbol, and the
453   // first operand is divisible by it. If the operation is a modulo, a constant
454   // zero expression is returned. In the case of floordiv and ceildiv, the
455   // symbol from the simplification of the second operand divides the first
456   // operand. Otherwise, simplification is not possible.
457   case AffineExprKind::FloorDiv:
458   case AffineExprKind::CeilDiv:
459   case AffineExprKind::Mod: {
460     AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
461     AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
462     AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
463     AffineSymbolExpr symbolExpr =
464         simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>();
465     if (!symbolExpr)
466       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
467     unsigned symbolPos = symbolExpr.getPosition();
468     if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
469       return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
470     if (expr.getKind() == AffineExprKind::Mod)
471       return getAffineConstantExpr(0, expr.getContext());
472     return symbolicDivide(sLHS, symbolPos, expr.getKind());
473   }
474   }
475   llvm_unreachable("Unknown AffineExpr");
476 }
477 
getAffineDimOrSymbol(AffineExprKind kind,unsigned position,MLIRContext * context)478 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
479                                        MLIRContext *context) {
480   auto assignCtx = [context](AffineDimExprStorage *storage) {
481     storage->context = context;
482   };
483 
484   StorageUniquer &uniquer = context->getAffineUniquer();
485   return uniquer.get<AffineDimExprStorage>(
486       assignCtx, static_cast<unsigned>(kind), position);
487 }
488 
getAffineDimExpr(unsigned position,MLIRContext * context)489 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
490   return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
491 }
492 
AffineSymbolExpr(AffineExpr::ImplType * ptr)493 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
494     : AffineExpr(ptr) {}
getPosition() const495 unsigned AffineSymbolExpr::getPosition() const {
496   return static_cast<ImplType *>(expr)->position;
497 }
498 
getAffineSymbolExpr(unsigned position,MLIRContext * context)499 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
500   return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
501   ;
502 }
503 
AffineConstantExpr(AffineExpr::ImplType * ptr)504 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
505     : AffineExpr(ptr) {}
getValue() const506 int64_t AffineConstantExpr::getValue() const {
507   return static_cast<ImplType *>(expr)->constant;
508 }
509 
operator ==(int64_t v) const510 bool AffineExpr::operator==(int64_t v) const {
511   return *this == getAffineConstantExpr(v, getContext());
512 }
513 
getAffineConstantExpr(int64_t constant,MLIRContext * context)514 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
515   auto assignCtx = [context](AffineConstantExprStorage *storage) {
516     storage->context = context;
517   };
518 
519   StorageUniquer &uniquer = context->getAffineUniquer();
520   return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
521 }
522 
523 /// Simplify add expression. Return nullptr if it can't be simplified.
simplifyAdd(AffineExpr lhs,AffineExpr rhs)524 static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
525   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
526   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
527   // Fold if both LHS, RHS are a constant.
528   if (lhsConst && rhsConst)
529     return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
530                                  lhs.getContext());
531 
532   // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
533   // If only one of them is a symbolic expressions, make it the RHS.
534   if (lhs.isa<AffineConstantExpr>() ||
535       (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
536     return rhs + lhs;
537   }
538 
539   // At this point, if there was a constant, it would be on the right.
540 
541   // Addition with a zero is a noop, return the other input.
542   if (rhsConst) {
543     if (rhsConst.getValue() == 0)
544       return lhs;
545   }
546   // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
547   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
548   if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
549     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
550       return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
551   }
552 
553   // Detect "c1 * expr + c_2 * expr" as "(c1 + c2) * expr".
554   // c1 is rRhsConst, c2 is rLhsConst; firstExpr, secondExpr are their
555   // respective multiplicands.
556   Optional<int64_t> rLhsConst, rRhsConst;
557   AffineExpr firstExpr, secondExpr;
558   AffineConstantExpr rLhsConstExpr;
559   auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
560   if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
561       (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
562     rLhsConst = rLhsConstExpr.getValue();
563     firstExpr = lBinOpExpr.getLHS();
564   } else {
565     rLhsConst = 1;
566     firstExpr = lhs;
567   }
568 
569   auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
570   AffineConstantExpr rRhsConstExpr;
571   if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
572       (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
573     rRhsConst = rRhsConstExpr.getValue();
574     secondExpr = rBinOpExpr.getLHS();
575   } else {
576     rRhsConst = 1;
577     secondExpr = rhs;
578   }
579 
580   if (rLhsConst && rRhsConst && firstExpr == secondExpr)
581     return getAffineBinaryOpExpr(
582         AffineExprKind::Mul, firstExpr,
583         getAffineConstantExpr(*rLhsConst + *rRhsConst, lhs.getContext()));
584 
585   // When doing successive additions, bring constant to the right: turn (d0 + 2)
586   // + d1 into (d0 + d1) + 2.
587   if (lBin && lBin.getKind() == AffineExprKind::Add) {
588     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
589       return lBin.getLHS() + rhs + lrhs;
590     }
591   }
592 
593   // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
594   // q may be a constant or symbolic expression. This leads to a much more
595   // efficient form when 'c' is a power of two, and in general a more compact
596   // and readable form.
597 
598   // Process '(expr floordiv c) * (-c)'.
599   if (!rBinOpExpr)
600     return nullptr;
601 
602   auto lrhs = rBinOpExpr.getLHS();
603   auto rrhs = rBinOpExpr.getRHS();
604 
605   AffineExpr llrhs, rlrhs;
606 
607   // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
608   // symbolic expression.
609   auto lrhsBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
610   // Check rrhsConstOpExpr = -1.
611   auto rrhsConstOpExpr = rrhs.dyn_cast<AffineConstantExpr>();
612   if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
613       lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
614     // Check llrhs = expr floordiv q.
615     llrhs = lrhsBinOpExpr.getLHS();
616     // Check rlrhs = q.
617     rlrhs = lrhsBinOpExpr.getRHS();
618     auto llrhsBinOpExpr = llrhs.dyn_cast<AffineBinaryOpExpr>();
619     if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
620       return nullptr;
621     if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
622       return lhs % rlrhs;
623   }
624 
625   // Process lrhs, which is 'expr floordiv c'.
626   AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
627   if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
628     return nullptr;
629 
630   llrhs = lrBinOpExpr.getLHS();
631   rlrhs = lrBinOpExpr.getRHS();
632 
633   if (lhs == llrhs && rlrhs == -rrhs) {
634     return lhs % rlrhs;
635   }
636   return nullptr;
637 }
638 
operator +(int64_t v) const639 AffineExpr AffineExpr::operator+(int64_t v) const {
640   return *this + getAffineConstantExpr(v, getContext());
641 }
operator +(AffineExpr other) const642 AffineExpr AffineExpr::operator+(AffineExpr other) const {
643   if (auto simplified = simplifyAdd(*this, other))
644     return simplified;
645 
646   StorageUniquer &uniquer = getContext()->getAffineUniquer();
647   return uniquer.get<AffineBinaryOpExprStorage>(
648       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
649 }
650 
/// Simplify a multiply expression. Return nullptr if it can't be simplified.
///
/// Rewrites are tried in this order:
///   1. constant * constant is folded;
///   2. the operands are swapped (and the product rebuilt through operator*)
///      so that a constant, or otherwise the symbolic operand, is on the RHS;
///   3. `x * 1` becomes `x` and `x * 0` becomes `0`;
///   4. `(e * c1) * c2` becomes `e * (c1 * c2)`;
///   5. `(e * c) * r` becomes `(e * r) * c`, keeping the constant outermost.
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();

  // NOTE(review): the int64_t product below is not checked for overflow.
  if (lhsConst && rhsConst)
    return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
                                 lhs.getContext());

  // A product of two operands that both involve dimensions is not a
  // (semi-)affine expression; callers must not request one.
  assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());

  // Canonicalize the mul expression so that the constant/symbolic term is the
  // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
  // constant. (Note that a constant is trivially symbolic).
  if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
    // At least one of them has to be symbolic.
    return rhs * lhs;
  }

  // At this point, if there was a constant, it would be on the right.

  // Multiplication with a one is a noop, return the other input.
  if (rhsConst) {
    if (rhsConst.getValue() == 1)
      return lhs;
    // Multiplication with zero.
    if (rhsConst.getValue() == 0)
      return rhsConst;
  }

  // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
      return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
  }

  // When doing successive multiplication, bring constant to the right: turn (d0
  // * 2) * d1 into (d0 * d1) * 2.
  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
      return (lBin.getLHS() * rhs) * lrhs;
    }
  }

  // No rewrite applies; the caller builds a plain Mul node.
  return nullptr;
}
698 
operator *(int64_t v) const699 AffineExpr AffineExpr::operator*(int64_t v) const {
700   return *this * getAffineConstantExpr(v, getContext());
701 }
operator *(AffineExpr other) const702 AffineExpr AffineExpr::operator*(AffineExpr other) const {
703   if (auto simplified = simplifyMul(*this, other))
704     return simplified;
705 
706   StorageUniquer &uniquer = getContext()->getAffineUniquer();
707   return uniquer.get<AffineBinaryOpExprStorage>(
708       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
709 }
710 
711 // Unary minus, delegate to operator*.
operator -() const712 AffineExpr AffineExpr::operator-() const {
713   return *this * getAffineConstantExpr(-1, getContext());
714 }
715 
716 // Delegate to operator+.
operator -(int64_t v) const717 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
operator -(AffineExpr other) const718 AffineExpr AffineExpr::operator-(AffineExpr other) const {
719   return *this + (-other);
720 }
721 
simplifyFloorDiv(AffineExpr lhs,AffineExpr rhs)722 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
723   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
724   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
725 
726   // mlir floordiv by zero or negative numbers is undefined and preserved as is.
727   if (!rhsConst || rhsConst.getValue() < 1)
728     return nullptr;
729 
730   if (lhsConst)
731     return getAffineConstantExpr(
732         floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
733 
734   // Fold floordiv of a multiply with a constant that is a multiple of the
735   // divisor. Eg: (i * 128) floordiv 64 = i * 2.
736   if (rhsConst == 1)
737     return lhs;
738 
739   // Simplify (expr * const) floordiv divConst when expr is known to be a
740   // multiple of divConst.
741   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
742   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
743     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
744       // rhsConst is known to be a positive constant.
745       if (lrhs.getValue() % rhsConst.getValue() == 0)
746         return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
747     }
748   }
749 
750   // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
751   // known to be a multiple of divConst.
752   if (lBin && lBin.getKind() == AffineExprKind::Add) {
753     int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
754     int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
755     // rhsConst is known to be a positive constant.
756     if (llhsDiv % rhsConst.getValue() == 0 ||
757         lrhsDiv % rhsConst.getValue() == 0)
758       return lBin.getLHS().floorDiv(rhsConst.getValue()) +
759              lBin.getRHS().floorDiv(rhsConst.getValue());
760   }
761 
762   return nullptr;
763 }
764 
floorDiv(uint64_t v) const765 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
766   return floorDiv(getAffineConstantExpr(v, getContext()));
767 }
floorDiv(AffineExpr other) const768 AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
769   if (auto simplified = simplifyFloorDiv(*this, other))
770     return simplified;
771 
772   StorageUniquer &uniquer = getContext()->getAffineUniquer();
773   return uniquer.get<AffineBinaryOpExprStorage>(
774       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
775       other);
776 }
777 
simplifyCeilDiv(AffineExpr lhs,AffineExpr rhs)778 static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
779   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
780   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
781 
782   if (!rhsConst || rhsConst.getValue() < 1)
783     return nullptr;
784 
785   if (lhsConst)
786     return getAffineConstantExpr(
787         ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
788 
789   // Fold ceildiv of a multiply with a constant that is a multiple of the
790   // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
791   if (rhsConst.getValue() == 1)
792     return lhs;
793 
794   // Simplify (expr * const) ceildiv divConst when const is known to be a
795   // multiple of divConst.
796   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
797   if (lBin && lBin.getKind() == AffineExprKind::Mul) {
798     if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
799       // rhsConst is known to be a positive constant.
800       if (lrhs.getValue() % rhsConst.getValue() == 0)
801         return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
802     }
803   }
804 
805   return nullptr;
806 }
807 
ceilDiv(uint64_t v) const808 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
809   return ceilDiv(getAffineConstantExpr(v, getContext()));
810 }
ceilDiv(AffineExpr other) const811 AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
812   if (auto simplified = simplifyCeilDiv(*this, other))
813     return simplified;
814 
815   StorageUniquer &uniquer = getContext()->getAffineUniquer();
816   return uniquer.get<AffineBinaryOpExprStorage>(
817       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
818       other);
819 }
820 
simplifyMod(AffineExpr lhs,AffineExpr rhs)821 static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
822   auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
823   auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
824 
825   // mod w.r.t zero or negative numbers is undefined and preserved as is.
826   if (!rhsConst || rhsConst.getValue() < 1)
827     return nullptr;
828 
829   if (lhsConst)
830     return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
831                                  lhs.getContext());
832 
833   // Fold modulo of an expression that is known to be a multiple of a constant
834   // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
835   // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
836   if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
837     return getAffineConstantExpr(0, lhs.getContext());
838 
839   // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
840   // known to be a multiple of divConst.
841   auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
842   if (lBin && lBin.getKind() == AffineExprKind::Add) {
843     int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
844     int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
845     // rhsConst is known to be a positive constant.
846     if (llhsDiv % rhsConst.getValue() == 0)
847       return lBin.getRHS() % rhsConst.getValue();
848     if (lrhsDiv % rhsConst.getValue() == 0)
849       return lBin.getLHS() % rhsConst.getValue();
850   }
851 
852   // Simplify (e % a) % b to e % b when b evenly divides a
853   if (lBin && lBin.getKind() == AffineExprKind::Mod) {
854     auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
855     if (intermediate && intermediate.getValue() >= 1 &&
856         mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
857       return lBin.getLHS() % rhsConst.getValue();
858     }
859   }
860 
861   return nullptr;
862 }
863 
operator %(uint64_t v) const864 AffineExpr AffineExpr::operator%(uint64_t v) const {
865   return *this % getAffineConstantExpr(v, getContext());
866 }
operator %(AffineExpr other) const867 AffineExpr AffineExpr::operator%(AffineExpr other) const {
868   if (auto simplified = simplifyMod(*this, other))
869     return simplified;
870 
871   StorageUniquer &uniquer = getContext()->getAffineUniquer();
872   return uniquer.get<AffineBinaryOpExprStorage>(
873       /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
874 }
875 
compose(AffineMap map) const876 AffineExpr AffineExpr::compose(AffineMap map) const {
877   SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
878                                              map.getResults().end());
879   return replaceDimsAndSymbols(dimReplacements, {});
880 }
operator <<(raw_ostream & os,AffineExpr expr)881 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
882   expr.print(os);
883   return os;
884 }
885 
886 /// Constructs an affine expression from a flat ArrayRef. If there are local
887 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
888 /// products expression, `localExprs` is expected to have the AffineExpr
889 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
890 /// in the format [dims, symbols, locals, constant term].
getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,unsigned numDims,unsigned numSymbols,ArrayRef<AffineExpr> localExprs,MLIRContext * context)891 AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
892                                            unsigned numDims,
893                                            unsigned numSymbols,
894                                            ArrayRef<AffineExpr> localExprs,
895                                            MLIRContext *context) {
896   // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
897   assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
898          "unexpected number of local expressions");
899 
900   auto expr = getAffineConstantExpr(0, context);
901   // Dimensions and symbols.
902   for (unsigned j = 0; j < numDims + numSymbols; j++) {
903     if (flatExprs[j] == 0)
904       continue;
905     auto id = j < numDims ? getAffineDimExpr(j, context)
906                           : getAffineSymbolExpr(j - numDims, context);
907     expr = expr + id * flatExprs[j];
908   }
909 
910   // Local identifiers.
911   for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
912        j++) {
913     if (flatExprs[j] == 0)
914       continue;
915     auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
916     expr = expr + term;
917   }
918 
919   // Constant term.
920   int64_t constTerm = flatExprs[flatExprs.size() - 1];
921   if (constTerm != 0)
922     expr = expr + constTerm;
923   return expr;
924 }
925 
926 /// Constructs a semi-affine expression from a flat ArrayRef. If there are
927 /// local identifiers (neither dimensional nor symbolic) that appear in the sum
928 /// of products expression, `localExprs` is expected to have the AffineExprs for
929 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
930 /// the format [dims, symbols, locals, constant term]. The semi-affine
931 /// expression is constructed in the sorted order of dimension and symbol
932 /// position numbers. Note:  local expressions/ids are used for mod, div as well
933 /// as symbolic RHS terms for terms that are not pure affine.
getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,unsigned numDims,unsigned numSymbols,ArrayRef<AffineExpr> localExprs,MLIRContext * context)934 static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
935                                                 unsigned numDims,
936                                                 unsigned numSymbols,
937                                                 ArrayRef<AffineExpr> localExprs,
938                                                 MLIRContext *context) {
939   assert(!flatExprs.empty() && "flatExprs cannot be empty");
940 
941   // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
942   assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
943          "unexpected number of local expressions");
944 
945   AffineExpr expr = getAffineConstantExpr(0, context);
946 
947   // We design indices as a pair which help us present the semi-affine map as
948   // sum of product where terms are sorted based on dimension or symbol
949   // position: <keyA, keyB> for expressions of the form dimension * symbol,
950   // where keyA is the position number of the dimension and keyB is the
951   // position number of the symbol. For dimensional expressions we set the index
952   // as (position number of the dimension, -1), as we want dimensional
953   // expressions to appear before symbolic and product of dimensional and
954   // symbolic expressions having the dimension with the same position number.
955   // For symbolic expression set the index as (position number of the symbol,
956   // maximum of last dimension and symbol position) number. For example, we want
957   // the expression we are constructing to look something like: d0 + d0 * s0 +
958   // s0 + d1*s1 + s1.
959 
960   // Stores the affine expression corresponding to a given index.
961   DenseMap<std::pair<unsigned, signed>, AffineExpr> indexToExprMap;
962   // Stores the constant coefficient value corresponding to a given
963   // dimension, symbol or a non-pure affine expression stored in `localExprs`.
964   DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
965   // Stores the indices as defined above, and later sorted to produce
966   // the semi-affine expression in the desired form.
967   SmallVector<std::pair<unsigned, signed>, 8> indices;
968 
969   // Example: expression = d0 + d0 * s0 + 2 * s0.
970   // indices = [{0,-1}, {0, 0}, {0, 1}]
971   // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
972   // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
973 
974   // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
975   auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
976                       AffineExpr expr) {
977     assert(!llvm::is_contained(indices, index) &&
978            "Key is already present in indices vector and overwriting will "
979            "happen in `indexToExprMap` and `coefficients`!");
980 
981     indices.push_back(index);
982     coefficients.insert({index, coefficient});
983     indexToExprMap.insert({index, expr});
984   };
985 
986   // Design indices for dimensional or symbolic terms, and store the indices,
987   // constant coefficient corresponding to the indices in `coefficients` map,
988   // and affine expression corresponding to indices in `indexToExprMap` map.
989 
990   for (unsigned j = 0; j < numDims; ++j) {
991     if (flatExprs[j] == 0)
992       continue;
993     // For dimensional expressions we set the index as <position number of the
994     // dimension, 0>, as we want dimensional expressions to appear before
995     // symbolic ones and products of dimensional and symbolic expressions
996     // having the dimension with the same position number.
997     std::pair<unsigned, signed> indexEntry(j, -1);
998     addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
999   }
1000   for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1001     if (flatExprs[j] == 0)
1002       continue;
1003     // For symbolic expression set the index as <position number
1004     // of the symbol, max(dimCount, symCount)> number,
1005     // as we want symbolic expressions with the same positional number to
1006     // appear after dimensional expressions having the same positional number.
1007     std::pair<unsigned, signed> indexEntry(j - numDims,
1008                                            std::max(numDims, numSymbols));
1009     addEntry(indexEntry, flatExprs[j],
1010              getAffineSymbolExpr(j - numDims, context));
1011   }
1012 
1013   // Denotes semi-affine product, modulo or division terms, which has been added
1014   // to the `indexToExpr` map.
1015   SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1016                                   false);
1017   unsigned lhsPos, rhsPos;
1018   // Construct indices for product terms involving dimension, symbol or constant
1019   // as lhs/rhs, and store the indices, constant coefficient corresponding to
1020   // the indices in `coefficients` map, and affine expression corresponding to
1021   // in indices in `indexToExprMap` map.
1022   for (const auto &it : llvm::enumerate(localExprs)) {
1023     AffineExpr expr = it.value();
1024     if (flatExprs[numDims + numSymbols + it.index()] == 0)
1025       continue;
1026     AffineExpr lhs = expr.cast<AffineBinaryOpExpr>().getLHS();
1027     AffineExpr rhs = expr.cast<AffineBinaryOpExpr>().getRHS();
1028     if (!((lhs.isa<AffineDimExpr>() || lhs.isa<AffineSymbolExpr>()) &&
1029           (rhs.isa<AffineDimExpr>() || rhs.isa<AffineSymbolExpr>() ||
1030            rhs.isa<AffineConstantExpr>()))) {
1031       continue;
1032     }
1033     if (rhs.isa<AffineConstantExpr>()) {
1034       // For product/modulo/division expressions, when rhs of modulo/division
1035       // expression is constant, we put 0 in place of keyB, because we want
1036       // them to appear earlier in the semi-affine expression we are
1037       // constructing. When rhs is constant, we place 0 in place of keyB.
1038       if (lhs.isa<AffineDimExpr>()) {
1039         lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1040         std::pair<unsigned, signed> indexEntry(lhsPos, -1);
1041         addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1042                  expr);
1043       } else {
1044         lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1045         std::pair<unsigned, signed> indexEntry(lhsPos,
1046                                                std::max(numDims, numSymbols));
1047         addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1048                  expr);
1049       }
1050     } else if (lhs.isa<AffineDimExpr>()) {
1051       // For product/modulo/division expressions having lhs as dimension and rhs
1052       // as symbol, we order the terms in the semi-affine expression based on
1053       // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1054       // where keyA is the position number of the dimension and keyB is the
1055       // position number of the symbol.
1056       lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1057       rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1058       std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1059       addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1060     } else {
1061       // For product/modulo/division expressions having both lhs and rhs as
1062       // symbol, we design indices as a pair: <keyA, keyB> for expressions
1063       // of the form dimension * symbol, where keyA is the position number of
1064       // the dimension and keyB is the position number of the symbol.
1065       lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1066       rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1067       std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1068       addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1069     }
1070     addedToMap[it.index()] = true;
1071   }
1072 
1073   // Constructing the simplified semi-affine sum of product/division/mod
1074   // expression from the flattened form in the desired sorted order of indices
1075   // of the various individual product/division/mod expressions.
1076   llvm::sort(indices);
1077   for (const std::pair<unsigned, unsigned> index : indices) {
1078     assert(indexToExprMap.lookup(index) &&
1079            "cannot find key in `indexToExprMap` map");
1080     expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1081   }
1082 
1083   // Local identifiers.
1084   for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1085        j++) {
1086     // If the coefficient of the local expression is 0, continue as we need not
1087     // add it in out final expression.
1088     if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1089       continue;
1090     auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1091     expr = expr + term;
1092   }
1093 
1094   // Constant term.
1095   int64_t constTerm = flatExprs.back();
1096   if (constTerm != 0)
1097     expr = expr + constTerm;
1098   return expr;
1099 }
1100 
SimpleAffineExprFlattener(unsigned numDims,unsigned numSymbols)1101 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
1102                                                      unsigned numSymbols)
1103     : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1104   operandExprStack.reserve(8);
1105 }
1106 
1107 // In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1108 //
1109 // In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1110 // introduce a local variable p (= expr * symbolic_expr), and the affine
1111 // expression expr * symbolic_expr is added to `localExprs`.
visitMulExpr(AffineBinaryOpExpr expr)1112 void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
1113   assert(operandExprStack.size() >= 2);
1114   SmallVector<int64_t, 8> rhs = operandExprStack.back();
1115   operandExprStack.pop_back();
1116   SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1117 
1118   // Flatten semi-affine multiplication expressions by introducing a local
1119   // variable in place of the product; the affine expression
1120   // corresponding to the quantifier is added to `localExprs`.
1121   if (!expr.getRHS().isa<AffineConstantExpr>()) {
1122     MLIRContext *context = expr.getContext();
1123     AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
1124                                              localExprs, context);
1125     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1126                                              localExprs, context);
1127     addLocalVariableSemiAffine(a * b, lhs, lhs.size());
1128     return;
1129   }
1130 
1131   // Get the RHS constant.
1132   auto rhsConst = rhs[getConstantIndex()];
1133   for (unsigned i = 0, e = lhs.size(); i < e; i++) {
1134     lhs[i] *= rhsConst;
1135   }
1136 }
1137 
visitAddExpr(AffineBinaryOpExpr expr)1138 void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
1139   assert(operandExprStack.size() >= 2);
1140   const auto &rhs = operandExprStack.back();
1141   auto &lhs = operandExprStack[operandExprStack.size() - 2];
1142   assert(lhs.size() == rhs.size());
1143   // Update the LHS in place.
1144   for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1145     lhs[i] += rhs[i];
1146   }
1147   // Pop off the RHS.
1148   operandExprStack.pop_back();
1149 }
1150 
1151 //
1152 // t = expr mod c   <=>  t = expr - c*q and c*q <= expr <= c*q + c - 1
1153 //
1154 // A mod expression "expr mod c" is thus flattened by introducing a new local
1155 // variable q (= expr floordiv c), such that expr mod c is replaced with
1156 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1157 //
1158 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1159 // introduce a local variable m (= expr mod symbolic_expr), and the affine
1160 // expression expr mod symbolic_expr is added to `localExprs`.
visitModExpr(AffineBinaryOpExpr expr)1161 void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
1162   assert(operandExprStack.size() >= 2);
1163 
1164   SmallVector<int64_t, 8> rhs = operandExprStack.back();
1165   operandExprStack.pop_back();
1166   SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1167   MLIRContext *context = expr.getContext();
1168 
1169   // Flatten semi affine modulo expressions by introducing a local
1170   // variable in place of the modulo value, and the affine expression
1171   // corresponding to the quantifier is added to `localExprs`.
1172   if (!expr.getRHS().isa<AffineConstantExpr>()) {
1173     AffineExpr dividendExpr = getAffineExprFromFlatForm(
1174         lhs, numDims, numSymbols, localExprs, context);
1175     AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1176                                                        localExprs, context);
1177     AffineExpr modExpr = dividendExpr % divisorExpr;
1178     addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
1179     return;
1180   }
1181 
1182   int64_t rhsConst = rhs[getConstantIndex()];
1183   // TODO: handle modulo by zero case when this issue is fixed
1184   // at the other places in the IR.
1185   assert(rhsConst > 0 && "RHS constant has to be positive");
1186 
1187   // Check if the LHS expression is a multiple of modulo factor.
1188   unsigned i, e;
1189   for (i = 0, e = lhs.size(); i < e; i++)
1190     if (lhs[i] % rhsConst != 0)
1191       break;
1192   // If yes, modulo expression here simplifies to zero.
1193   if (i == lhs.size()) {
1194     std::fill(lhs.begin(), lhs.end(), 0);
1195     return;
1196   }
1197 
1198   // Add a local variable for the quotient, i.e., expr % c is replaced by
1199   // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1200   // the GCD of expr and c.
1201   SmallVector<int64_t, 8> floorDividend(lhs);
1202   uint64_t gcd = rhsConst;
1203   for (unsigned i = 0, e = lhs.size(); i < e; i++)
1204     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1205   // Simplify the numerator and the denominator.
1206   if (gcd != 1) {
1207     for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
1208       floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
1209   }
1210   int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1211 
1212   // Construct the AffineExpr form of the floordiv to store in localExprs.
1213 
1214   AffineExpr dividendExpr = getAffineExprFromFlatForm(
1215       floorDividend, numDims, numSymbols, localExprs, context);
1216   AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
1217   AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
1218   int loc;
1219   if ((loc = findLocalId(floorDivExpr)) == -1) {
1220     addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
1221     // Set result at top of stack to "lhs - rhsConst * q".
1222     lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1223   } else {
1224     // Reuse the existing local id.
1225     lhs[getLocalVarStartIndex() + loc] = -rhsConst;
1226   }
1227 }
1228 
visitCeilDivExpr(AffineBinaryOpExpr expr)1229 void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
1230   visitDivExpr(expr, /*isCeil=*/true);
1231 }
visitFloorDivExpr(AffineBinaryOpExpr expr)1232 void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
1233   visitDivExpr(expr, /*isCeil=*/false);
1234 }
1235 
visitDimExpr(AffineDimExpr expr)1236 void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
1237   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1238   auto &eq = operandExprStack.back();
1239   assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1240   eq[getDimStartIndex() + expr.getPosition()] = 1;
1241 }
1242 
visitSymbolExpr(AffineSymbolExpr expr)1243 void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
1244   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1245   auto &eq = operandExprStack.back();
1246   assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1247   eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1248 }
1249 
visitConstantExpr(AffineConstantExpr expr)1250 void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
1251   operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1252   auto &eq = operandExprStack.back();
1253   eq[getConstantIndex()] = expr.getValue();
1254 }
1255 
addLocalVariableSemiAffine(AffineExpr expr,SmallVectorImpl<int64_t> & result,unsigned long resultSize)1256 void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1257     AffineExpr expr, SmallVectorImpl<int64_t> &result,
1258     unsigned long resultSize) {
1259   assert(result.size() == resultSize &&
1260          "`result` vector passed is not of correct size");
1261   int loc;
1262   if ((loc = findLocalId(expr)) == -1)
1263     addLocalIdSemiAffine(expr);
1264   std::fill(result.begin(), result.end(), 0);
1265   if (loc == -1)
1266     result[getLocalVarStartIndex() + numLocals - 1] = 1;
1267   else
1268     result[getLocalVarStartIndex() + loc] = 1;
1269 }
1270 
1271 // t = expr floordiv c   <=> t = q, c * q <= expr <= c * q + c - 1
1272 // A floordiv is thus flattened by introducing a new local variable q, and
1273 // replacing that expression with 'q' while adding the constraints
1274 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
1275 // FlatAffineConstraints::addLocalFloorDiv).
1276 //
1277 // A ceildiv is similarly flattened:
1278 // t = expr ceildiv c   <=> t =  (expr + c - 1) floordiv c
1279 //
1280 // In case of semi affine division expressions, t = expr floordiv symbolic_expr
1281 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1282 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1283 // `localExprs`.
visitDivExpr(AffineBinaryOpExpr expr,bool isCeil)1284 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1285                                              bool isCeil) {
1286   assert(operandExprStack.size() >= 2);
1287 
1288   MLIRContext *context = expr.getContext();
1289   SmallVector<int64_t, 8> rhs = operandExprStack.back();
1290   operandExprStack.pop_back();
1291   SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1292 
1293   // Flatten semi affine division expressions by introducing a local
1294   // variable in place of the quotient, and the affine expression corresponding
1295   // to the quantifier is added to `localExprs`.
1296   if (!expr.getRHS().isa<AffineConstantExpr>()) {
1297     AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
1298                                              localExprs, context);
1299     AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1300                                              localExprs, context);
1301     AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1302     addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
1303     return;
1304   }
1305 
1306   // This is a pure affine expr; the RHS is a positive constant.
1307   int64_t rhsConst = rhs[getConstantIndex()];
1308   // TODO: handle division by zero at the same time the issue is
1309   // fixed at other places.
1310   assert(rhsConst > 0 && "RHS constant has to be positive");
1311 
1312   // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1313   // common divisors of the numerator and denominator.
1314   uint64_t gcd = std::abs(rhsConst);
1315   for (unsigned i = 0, e = lhs.size(); i < e; i++)
1316     gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1317   // Simplify the numerator and the denominator.
1318   if (gcd != 1) {
1319     for (unsigned i = 0, e = lhs.size(); i < e; i++)
1320       lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
1321   }
1322   int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1323   // If the divisor becomes 1, the updated LHS is the result. (The
1324   // divisor can't be negative since rhsConst is positive).
1325   if (divisor == 1)
1326     return;
1327 
1328   // If the divisor cannot be simplified to one, we will have to retain
1329   // the ceil/floor expr (simplified up until here). Add an existential
1330   // quantifier to express its result, i.e., expr1 div expr2 is replaced
1331   // by a new identifier, q.
1332   AffineExpr a =
1333       getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context);
1334   AffineExpr b = getAffineConstantExpr(divisor, context);
1335 
1336   int loc;
1337   AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1338   if ((loc = findLocalId(divExpr)) == -1) {
1339     if (!isCeil) {
1340       SmallVector<int64_t, 8> dividend(lhs);
1341       addLocalFloorDivId(dividend, divisor, divExpr);
1342     } else {
1343       // lhs ceildiv c <=>  (lhs + c - 1) floordiv c
1344       SmallVector<int64_t, 8> dividend(lhs);
1345       dividend.back() += divisor - 1;
1346       addLocalFloorDivId(dividend, divisor, divExpr);
1347     }
1348   }
1349   // Set the expression on stack to the local var introduced to capture the
1350   // result of the division (floor or ceil).
1351   std::fill(lhs.begin(), lhs.end(), 0);
1352   if (loc == -1)
1353     lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1354   else
1355     lhs[getLocalVarStartIndex() + loc] = 1;
1356 }
1357 
1358 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1359 // The local identifier added is always a floordiv of a pure add/mul affine
1360 // function of other identifiers, coefficients of which are specified in
1361 // dividend and with respect to a positive constant divisor. localExpr is the
1362 // simplified tree expression (AffineExpr) corresponding to the quantifier.
addLocalFloorDivId(ArrayRef<int64_t> dividend,int64_t divisor,AffineExpr localExpr)1363 void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
1364                                                    int64_t divisor,
1365                                                    AffineExpr localExpr) {
1366   assert(divisor > 0 && "positive constant divisor expected");
1367   for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1368     subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1369   localExprs.push_back(localExpr);
1370   numLocals++;
1371   // dividend and divisor are not used here; an override of this method uses it.
1372 }
1373 
addLocalIdSemiAffine(AffineExpr localExpr)1374 void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr) {
1375   for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1376     subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1377   localExprs.push_back(localExpr);
1378   ++numLocals;
1379 }
1380 
findLocalId(AffineExpr localExpr)1381 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1382   SmallVectorImpl<AffineExpr>::iterator it;
1383   if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1384     return -1;
1385   return it - localExprs.begin();
1386 }
1387 
1388 /// Simplify the affine expression by flattening it and reconstructing it.
simplifyAffineExpr(AffineExpr expr,unsigned numDims,unsigned numSymbols)1389 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
1390                                     unsigned numSymbols) {
1391   // Simplify semi-affine expressions separately.
1392   if (!expr.isPureAffine())
1393     expr = simplifySemiAffine(expr);
1394 
1395   SimpleAffineExprFlattener flattener(numDims, numSymbols);
1396   flattener.walkPostOrder(expr);
1397   ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1398   if (!expr.isPureAffine() &&
1399       expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1400                                         flattener.localExprs,
1401                                         expr.getContext()))
1402     return expr;
1403   AffineExpr simplifiedExpr =
1404       expr.isPureAffine()
1405           ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1406                                       flattener.localExprs, expr.getContext())
1407           : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1408                                           flattener.localExprs,
1409                                           expr.getContext());
1410 
1411   flattener.operandExprStack.pop_back();
1412   assert(flattener.operandExprStack.empty());
1413   return simplifiedExpr;
1414 }
1415