1 //===- AffineExpr.cpp - MLIR Affine Expr Classes --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include <utility>
10
11 #include "AffineExprDetail.h"
12 #include "mlir/IR/AffineExpr.h"
13 #include "mlir/IR/AffineExprVisitor.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/Support/MathExtras.h"
17 #include "mlir/Support/TypeID.h"
18 #include "llvm/ADT/STLExtras.h"
19
20 using namespace mlir;
21 using namespace mlir::detail;
22
getContext() const23 MLIRContext *AffineExpr::getContext() const { return expr->context; }
24
getKind() const25 AffineExprKind AffineExpr::getKind() const { return expr->kind; }
26
27 /// Walk all of the AffineExprs in this subgraph in postorder.
walk(std::function<void (AffineExpr)> callback) const28 void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
29 struct AffineExprWalker : public AffineExprVisitor<AffineExprWalker> {
30 std::function<void(AffineExpr)> callback;
31
32 AffineExprWalker(std::function<void(AffineExpr)> callback)
33 : callback(std::move(callback)) {}
34
35 void visitAffineBinaryOpExpr(AffineBinaryOpExpr expr) { callback(expr); }
36 void visitConstantExpr(AffineConstantExpr expr) { callback(expr); }
37 void visitDimExpr(AffineDimExpr expr) { callback(expr); }
38 void visitSymbolExpr(AffineSymbolExpr expr) { callback(expr); }
39 };
40
41 AffineExprWalker(std::move(callback)).walkPostOrder(*this);
42 }
43
44 // Dispatch affine expression construction based on kind.
getAffineBinaryOpExpr(AffineExprKind kind,AffineExpr lhs,AffineExpr rhs)45 AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
46 AffineExpr rhs) {
47 if (kind == AffineExprKind::Add)
48 return lhs + rhs;
49 if (kind == AffineExprKind::Mul)
50 return lhs * rhs;
51 if (kind == AffineExprKind::FloorDiv)
52 return lhs.floorDiv(rhs);
53 if (kind == AffineExprKind::CeilDiv)
54 return lhs.ceilDiv(rhs);
55 if (kind == AffineExprKind::Mod)
56 return lhs % rhs;
57
58 llvm_unreachable("unknown binary operation on affine expressions");
59 }
60
61 /// This method substitutes any uses of dimensions and symbols (e.g.
62 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
63 AffineExpr
replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,ArrayRef<AffineExpr> symReplacements) const64 AffineExpr::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
65 ArrayRef<AffineExpr> symReplacements) const {
66 switch (getKind()) {
67 case AffineExprKind::Constant:
68 return *this;
69 case AffineExprKind::DimId: {
70 unsigned dimId = cast<AffineDimExpr>().getPosition();
71 if (dimId >= dimReplacements.size())
72 return *this;
73 return dimReplacements[dimId];
74 }
75 case AffineExprKind::SymbolId: {
76 unsigned symId = cast<AffineSymbolExpr>().getPosition();
77 if (symId >= symReplacements.size())
78 return *this;
79 return symReplacements[symId];
80 }
81 case AffineExprKind::Add:
82 case AffineExprKind::Mul:
83 case AffineExprKind::FloorDiv:
84 case AffineExprKind::CeilDiv:
85 case AffineExprKind::Mod:
86 auto binOp = cast<AffineBinaryOpExpr>();
87 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
88 auto newLHS = lhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
89 auto newRHS = rhs.replaceDimsAndSymbols(dimReplacements, symReplacements);
90 if (newLHS == lhs && newRHS == rhs)
91 return *this;
92 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
93 }
94 llvm_unreachable("Unknown AffineExpr");
95 }
96
replaceDims(ArrayRef<AffineExpr> dimReplacements) const97 AffineExpr AffineExpr::replaceDims(ArrayRef<AffineExpr> dimReplacements) const {
98 return replaceDimsAndSymbols(dimReplacements, {});
99 }
100
101 AffineExpr
replaceSymbols(ArrayRef<AffineExpr> symReplacements) const102 AffineExpr::replaceSymbols(ArrayRef<AffineExpr> symReplacements) const {
103 return replaceDimsAndSymbols({}, symReplacements);
104 }
105
106 /// Replace dims[offset ... numDims)
107 /// by dims[offset + shift ... shift + numDims).
shiftDims(unsigned numDims,unsigned shift,unsigned offset) const108 AffineExpr AffineExpr::shiftDims(unsigned numDims, unsigned shift,
109 unsigned offset) const {
110 SmallVector<AffineExpr, 4> dims;
111 for (unsigned idx = 0; idx < offset; ++idx)
112 dims.push_back(getAffineDimExpr(idx, getContext()));
113 for (unsigned idx = offset; idx < numDims; ++idx)
114 dims.push_back(getAffineDimExpr(idx + shift, getContext()));
115 return replaceDimsAndSymbols(dims, {});
116 }
117
118 /// Replace symbols[offset ... numSymbols)
119 /// by symbols[offset + shift ... shift + numSymbols).
shiftSymbols(unsigned numSymbols,unsigned shift,unsigned offset) const120 AffineExpr AffineExpr::shiftSymbols(unsigned numSymbols, unsigned shift,
121 unsigned offset) const {
122 SmallVector<AffineExpr, 4> symbols;
123 for (unsigned idx = 0; idx < offset; ++idx)
124 symbols.push_back(getAffineSymbolExpr(idx, getContext()));
125 for (unsigned idx = offset; idx < numSymbols; ++idx)
126 symbols.push_back(getAffineSymbolExpr(idx + shift, getContext()));
127 return replaceDimsAndSymbols({}, symbols);
128 }
129
130 /// Sparse replace method. Return the modified expression tree.
131 AffineExpr
replace(const DenseMap<AffineExpr,AffineExpr> & map) const132 AffineExpr::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
133 auto it = map.find(*this);
134 if (it != map.end())
135 return it->second;
136 switch (getKind()) {
137 default:
138 return *this;
139 case AffineExprKind::Add:
140 case AffineExprKind::Mul:
141 case AffineExprKind::FloorDiv:
142 case AffineExprKind::CeilDiv:
143 case AffineExprKind::Mod:
144 auto binOp = cast<AffineBinaryOpExpr>();
145 auto lhs = binOp.getLHS(), rhs = binOp.getRHS();
146 auto newLHS = lhs.replace(map);
147 auto newRHS = rhs.replace(map);
148 if (newLHS == lhs && newRHS == rhs)
149 return *this;
150 return getAffineBinaryOpExpr(getKind(), newLHS, newRHS);
151 }
152 llvm_unreachable("Unknown AffineExpr");
153 }
154
155 /// Sparse replace method. Return the modified expression tree.
replace(AffineExpr expr,AffineExpr replacement) const156 AffineExpr AffineExpr::replace(AffineExpr expr, AffineExpr replacement) const {
157 DenseMap<AffineExpr, AffineExpr> map;
158 map.insert(std::make_pair(expr, replacement));
159 return replace(map);
160 }
161 /// Returns true if this expression is made out of only symbols and
162 /// constants (no dimensional identifiers).
isSymbolicOrConstant() const163 bool AffineExpr::isSymbolicOrConstant() const {
164 switch (getKind()) {
165 case AffineExprKind::Constant:
166 return true;
167 case AffineExprKind::DimId:
168 return false;
169 case AffineExprKind::SymbolId:
170 return true;
171
172 case AffineExprKind::Add:
173 case AffineExprKind::Mul:
174 case AffineExprKind::FloorDiv:
175 case AffineExprKind::CeilDiv:
176 case AffineExprKind::Mod: {
177 auto expr = this->cast<AffineBinaryOpExpr>();
178 return expr.getLHS().isSymbolicOrConstant() &&
179 expr.getRHS().isSymbolicOrConstant();
180 }
181 }
182 llvm_unreachable("Unknown AffineExpr");
183 }
184
185 /// Returns true if this is a pure affine expression, i.e., multiplication,
186 /// floordiv, ceildiv, and mod is only allowed w.r.t constants.
isPureAffine() const187 bool AffineExpr::isPureAffine() const {
188 switch (getKind()) {
189 case AffineExprKind::SymbolId:
190 case AffineExprKind::DimId:
191 case AffineExprKind::Constant:
192 return true;
193 case AffineExprKind::Add: {
194 auto op = cast<AffineBinaryOpExpr>();
195 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine();
196 }
197
198 case AffineExprKind::Mul: {
199 // TODO: Canonicalize the constants in binary operators to the RHS when
200 // possible, allowing this to merge into the next case.
201 auto op = cast<AffineBinaryOpExpr>();
202 return op.getLHS().isPureAffine() && op.getRHS().isPureAffine() &&
203 (op.getLHS().template isa<AffineConstantExpr>() ||
204 op.getRHS().template isa<AffineConstantExpr>());
205 }
206 case AffineExprKind::FloorDiv:
207 case AffineExprKind::CeilDiv:
208 case AffineExprKind::Mod: {
209 auto op = cast<AffineBinaryOpExpr>();
210 return op.getLHS().isPureAffine() &&
211 op.getRHS().template isa<AffineConstantExpr>();
212 }
213 }
214 llvm_unreachable("Unknown AffineExpr");
215 }
216
217 // Returns the greatest known integral divisor of this affine expression.
getLargestKnownDivisor() const218 int64_t AffineExpr::getLargestKnownDivisor() const {
219 AffineBinaryOpExpr binExpr(nullptr);
220 switch (getKind()) {
221 case AffineExprKind::CeilDiv:
222 LLVM_FALLTHROUGH;
223 case AffineExprKind::DimId:
224 case AffineExprKind::FloorDiv:
225 case AffineExprKind::SymbolId:
226 return 1;
227 case AffineExprKind::Constant:
228 return std::abs(this->cast<AffineConstantExpr>().getValue());
229 case AffineExprKind::Mul: {
230 binExpr = this->cast<AffineBinaryOpExpr>();
231 return binExpr.getLHS().getLargestKnownDivisor() *
232 binExpr.getRHS().getLargestKnownDivisor();
233 }
234 case AffineExprKind::Add:
235 LLVM_FALLTHROUGH;
236 case AffineExprKind::Mod: {
237 binExpr = cast<AffineBinaryOpExpr>();
238 return llvm::GreatestCommonDivisor64(
239 binExpr.getLHS().getLargestKnownDivisor(),
240 binExpr.getRHS().getLargestKnownDivisor());
241 }
242 }
243 llvm_unreachable("Unknown AffineExpr");
244 }
245
isMultipleOf(int64_t factor) const246 bool AffineExpr::isMultipleOf(int64_t factor) const {
247 AffineBinaryOpExpr binExpr(nullptr);
248 uint64_t l, u;
249 switch (getKind()) {
250 case AffineExprKind::SymbolId:
251 LLVM_FALLTHROUGH;
252 case AffineExprKind::DimId:
253 return factor * factor == 1;
254 case AffineExprKind::Constant:
255 return cast<AffineConstantExpr>().getValue() % factor == 0;
256 case AffineExprKind::Mul: {
257 binExpr = cast<AffineBinaryOpExpr>();
258 // It's probably not worth optimizing this further (to not traverse the
259 // whole sub-tree under - it that would require a version of isMultipleOf
260 // that on a 'false' return also returns the largest known divisor).
261 return (l = binExpr.getLHS().getLargestKnownDivisor()) % factor == 0 ||
262 (u = binExpr.getRHS().getLargestKnownDivisor()) % factor == 0 ||
263 (l * u) % factor == 0;
264 }
265 case AffineExprKind::Add:
266 case AffineExprKind::FloorDiv:
267 case AffineExprKind::CeilDiv:
268 case AffineExprKind::Mod: {
269 binExpr = cast<AffineBinaryOpExpr>();
270 return llvm::GreatestCommonDivisor64(
271 binExpr.getLHS().getLargestKnownDivisor(),
272 binExpr.getRHS().getLargestKnownDivisor()) %
273 factor ==
274 0;
275 }
276 }
277 llvm_unreachable("Unknown AffineExpr");
278 }
279
isFunctionOfDim(unsigned position) const280 bool AffineExpr::isFunctionOfDim(unsigned position) const {
281 if (getKind() == AffineExprKind::DimId) {
282 return *this == mlir::getAffineDimExpr(position, getContext());
283 }
284 if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
285 return expr.getLHS().isFunctionOfDim(position) ||
286 expr.getRHS().isFunctionOfDim(position);
287 }
288 return false;
289 }
290
isFunctionOfSymbol(unsigned position) const291 bool AffineExpr::isFunctionOfSymbol(unsigned position) const {
292 if (getKind() == AffineExprKind::SymbolId) {
293 return *this == mlir::getAffineSymbolExpr(position, getContext());
294 }
295 if (auto expr = this->dyn_cast<AffineBinaryOpExpr>()) {
296 return expr.getLHS().isFunctionOfSymbol(position) ||
297 expr.getRHS().isFunctionOfSymbol(position);
298 }
299 return false;
300 }
301
AffineBinaryOpExpr(AffineExpr::ImplType * ptr)302 AffineBinaryOpExpr::AffineBinaryOpExpr(AffineExpr::ImplType *ptr)
303 : AffineExpr(ptr) {}
getLHS() const304 AffineExpr AffineBinaryOpExpr::getLHS() const {
305 return static_cast<ImplType *>(expr)->lhs;
306 }
getRHS() const307 AffineExpr AffineBinaryOpExpr::getRHS() const {
308 return static_cast<ImplType *>(expr)->rhs;
309 }
310
AffineDimExpr(AffineExpr::ImplType * ptr)311 AffineDimExpr::AffineDimExpr(AffineExpr::ImplType *ptr) : AffineExpr(ptr) {}
getPosition() const312 unsigned AffineDimExpr::getPosition() const {
313 return static_cast<ImplType *>(expr)->position;
314 }
315
316 /// Returns true if the expression is divisible by the given symbol with
317 /// position `symbolPos`. The argument `opKind` specifies here what kind of
318 /// division or mod operation called this division. It helps in implementing the
319 /// commutative property of the floordiv and ceildiv operations. If the argument
320 ///`exprKind` is floordiv and `expr` is also a binary expression of a floordiv
321 /// operation, then the commutative property can be used otherwise, the floordiv
322 /// operation is not divisible. The same argument holds for ceildiv operation.
isDivisibleBySymbol(AffineExpr expr,unsigned symbolPos,AffineExprKind opKind)323 static bool isDivisibleBySymbol(AffineExpr expr, unsigned symbolPos,
324 AffineExprKind opKind) {
325 // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
326 assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
327 opKind == AffineExprKind::CeilDiv) &&
328 "unexpected opKind");
329 switch (expr.getKind()) {
330 case AffineExprKind::Constant:
331 return expr.cast<AffineConstantExpr>().getValue() == 0;
332 case AffineExprKind::DimId:
333 return false;
334 case AffineExprKind::SymbolId:
335 return (expr.cast<AffineSymbolExpr>().getPosition() == symbolPos);
336 // Checks divisibility by the given symbol for both operands.
337 case AffineExprKind::Add: {
338 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
339 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) &&
340 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
341 }
342 // Checks divisibility by the given symbol for both operands. Consider the
343 // expression `(((s1*s0) floordiv w) mod ((s1 * s2) floordiv p)) floordiv s1`,
344 // this is a division by s1 and both the operands of modulo are divisible by
345 // s1 but it is not divisible by s1 always. The third argument is
346 // `AffineExprKind::Mod` for this reason.
347 case AffineExprKind::Mod: {
348 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
349 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos,
350 AffineExprKind::Mod) &&
351 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos,
352 AffineExprKind::Mod);
353 }
354 // Checks if any of the operand divisible by the given symbol.
355 case AffineExprKind::Mul: {
356 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
357 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind) ||
358 isDivisibleBySymbol(binaryExpr.getRHS(), symbolPos, opKind);
359 }
360 // Floordiv and ceildiv are divisible by the given symbol when the first
361 // operand is divisible, and the affine expression kind of the argument expr
362 // is same as the argument `opKind`. This can be inferred from commutative
363 // property of floordiv and ceildiv operations and are as follow:
364 // (exp1 floordiv exp2) floordiv exp3 = (exp1 floordiv exp3) floordiv exp2
365 // (exp1 ceildiv exp2) ceildiv exp3 = (exp1 ceildiv exp3) ceildiv expr2
366 // It will fail if operations are not same. For example:
367 // (exps1 ceildiv exp2) floordiv exp3 can not be simplified.
368 case AffineExprKind::FloorDiv:
369 case AffineExprKind::CeilDiv: {
370 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
371 if (opKind != expr.getKind())
372 return false;
373 return isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind());
374 }
375 }
376 llvm_unreachable("Unknown AffineExpr");
377 }
378
/// Divides `expr` by the symbol at position `symbolPos` and returns the
/// quotient. A null expression is returned when the division is not possible
/// (a non-zero constant or a dimension is reached).
///
/// Precondition: the caller has already established divisibility with
/// isDivisibleBySymbol(expr, symbolPos, opKind). This function mirrors that
/// check's recursion but does not re-validate it. In particular, the SymbolId
/// case returns 1 without looking at the symbol's position.
/// NOTE(review): in the Add/Mod/Mul/FloorDiv/CeilDiv cases a null result from
/// a recursive call is passed straight into getAffineBinaryOpExpr / operator*
/// with no check. Callers must honour the precondition above.
static AffineExpr symbolicDivide(AffineExpr expr, unsigned symbolPos,
                                 AffineExprKind opKind) {
  // The argument `opKind` can either be Modulo, Floordiv or Ceildiv only.
  assert((opKind == AffineExprKind::Mod || opKind == AffineExprKind::FloorDiv ||
          opKind == AffineExprKind::CeilDiv) &&
         "unexpected opKind");
  switch (expr.getKind()) {
  // Only zero is divisible by a symbol; 0 / s == 0.
  case AffineExprKind::Constant:
    if (expr.cast<AffineConstantExpr>().getValue() != 0)
      return nullptr;
    return getAffineConstantExpr(0, expr.getContext());
  case AffineExprKind::DimId:
    return nullptr;
  // Assumed to be the divisor symbol itself (see the precondition); s / s == 1.
  case AffineExprKind::SymbolId:
    return getAffineConstantExpr(1, expr.getContext());
  // Dividing both operands by the given symbol.
  case AffineExprKind::Add: {
    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
    return getAffineBinaryOpExpr(
        expr.getKind(), symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind),
        symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind));
  }
  // Dividing both operands by the given symbol. The recursion uses Mod as the
  // op kind, matching the way isDivisibleBySymbol checks mod operands.
  case AffineExprKind::Mod: {
    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
    return getAffineBinaryOpExpr(
        expr.getKind(),
        symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
        symbolicDivide(binaryExpr.getRHS(), symbolPos, expr.getKind()));
  }
  // Dividing one of the operands by the given symbol: the LHS if it is
  // divisible, otherwise the RHS.
  case AffineExprKind::Mul: {
    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
    if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, opKind))
      return binaryExpr.getLHS() *
             symbolicDivide(binaryExpr.getRHS(), symbolPos, opKind);
    return symbolicDivide(binaryExpr.getLHS(), symbolPos, opKind) *
           binaryExpr.getRHS();
  }
  // Dividing first operand only by the given symbol, which is valid through
  // the commutative property of same-kind floordiv/ceildiv.
  case AffineExprKind::FloorDiv:
  case AffineExprKind::CeilDiv: {
    AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
    return getAffineBinaryOpExpr(
        expr.getKind(),
        symbolicDivide(binaryExpr.getLHS(), symbolPos, expr.getKind()),
        binaryExpr.getRHS());
  }
  }
  llvm_unreachable("Unknown AffineExpr");
}
433
434 /// Simplify a semi-affine expression by handling modulo, floordiv, or ceildiv
435 /// operations when the second operand simplifies to a symbol and the first
436 /// operand is divisible by that symbol. It can be applied to any semi-affine
437 /// expression. Returned expression can either be a semi-affine or pure affine
438 /// expression.
simplifySemiAffine(AffineExpr expr)439 static AffineExpr simplifySemiAffine(AffineExpr expr) {
440 switch (expr.getKind()) {
441 case AffineExprKind::Constant:
442 case AffineExprKind::DimId:
443 case AffineExprKind::SymbolId:
444 return expr;
445 case AffineExprKind::Add:
446 case AffineExprKind::Mul: {
447 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
448 return getAffineBinaryOpExpr(expr.getKind(),
449 simplifySemiAffine(binaryExpr.getLHS()),
450 simplifySemiAffine(binaryExpr.getRHS()));
451 }
452 // Check if the simplification of the second operand is a symbol, and the
453 // first operand is divisible by it. If the operation is a modulo, a constant
454 // zero expression is returned. In the case of floordiv and ceildiv, the
455 // symbol from the simplification of the second operand divides the first
456 // operand. Otherwise, simplification is not possible.
457 case AffineExprKind::FloorDiv:
458 case AffineExprKind::CeilDiv:
459 case AffineExprKind::Mod: {
460 AffineBinaryOpExpr binaryExpr = expr.cast<AffineBinaryOpExpr>();
461 AffineExpr sLHS = simplifySemiAffine(binaryExpr.getLHS());
462 AffineExpr sRHS = simplifySemiAffine(binaryExpr.getRHS());
463 AffineSymbolExpr symbolExpr =
464 simplifySemiAffine(binaryExpr.getRHS()).dyn_cast<AffineSymbolExpr>();
465 if (!symbolExpr)
466 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
467 unsigned symbolPos = symbolExpr.getPosition();
468 if (!isDivisibleBySymbol(binaryExpr.getLHS(), symbolPos, expr.getKind()))
469 return getAffineBinaryOpExpr(expr.getKind(), sLHS, sRHS);
470 if (expr.getKind() == AffineExprKind::Mod)
471 return getAffineConstantExpr(0, expr.getContext());
472 return symbolicDivide(sLHS, symbolPos, expr.getKind());
473 }
474 }
475 llvm_unreachable("Unknown AffineExpr");
476 }
477
getAffineDimOrSymbol(AffineExprKind kind,unsigned position,MLIRContext * context)478 static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
479 MLIRContext *context) {
480 auto assignCtx = [context](AffineDimExprStorage *storage) {
481 storage->context = context;
482 };
483
484 StorageUniquer &uniquer = context->getAffineUniquer();
485 return uniquer.get<AffineDimExprStorage>(
486 assignCtx, static_cast<unsigned>(kind), position);
487 }
488
getAffineDimExpr(unsigned position,MLIRContext * context)489 AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
490 return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
491 }
492
AffineSymbolExpr(AffineExpr::ImplType * ptr)493 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
494 : AffineExpr(ptr) {}
getPosition() const495 unsigned AffineSymbolExpr::getPosition() const {
496 return static_cast<ImplType *>(expr)->position;
497 }
498
getAffineSymbolExpr(unsigned position,MLIRContext * context)499 AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
500 return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
501 ;
502 }
503
AffineConstantExpr(AffineExpr::ImplType * ptr)504 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
505 : AffineExpr(ptr) {}
getValue() const506 int64_t AffineConstantExpr::getValue() const {
507 return static_cast<ImplType *>(expr)->constant;
508 }
509
operator ==(int64_t v) const510 bool AffineExpr::operator==(int64_t v) const {
511 return *this == getAffineConstantExpr(v, getContext());
512 }
513
getAffineConstantExpr(int64_t constant,MLIRContext * context)514 AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
515 auto assignCtx = [context](AffineConstantExprStorage *storage) {
516 storage->context = context;
517 };
518
519 StorageUniquer &uniquer = context->getAffineUniquer();
520 return uniquer.get<AffineConstantExprStorage>(assignCtx, constant);
521 }
522
/// Simplify add expression. Return nullptr if it can't be simplified.
///
/// The rewrites below are tried in order. Several of them re-enter
/// AffineExpr::operator+ (and therefore this function) with rearranged
/// operands, so later rules rely on the canonical shape established by earlier
/// ones (e.g. a lone constant operand is always on the RHS).
static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
  // Fold if both LHS, RHS are a constant.
  if (lhsConst && rhsConst)
    return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
                                 lhs.getContext());

  // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
  // If only one of them is a symbolic expression, make it the RHS.
  if (lhs.isa<AffineConstantExpr>() ||
      (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
    return rhs + lhs;
  }

  // At this point, if there was a constant, it would be on the right.

  // Addition with a zero is a noop, return the other input.
  if (rhsConst) {
    if (rhsConst.getValue() == 0)
      return lhs;
  }
  // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
      return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
  }

  // Detect "c1 * expr + c2 * expr" as "(c1 + c2) * expr".
  // c1 is rLhsConst, c2 is rRhsConst; firstExpr, secondExpr are their
  // respective multiplicands. An operand not of the form `expr * constant` is
  // treated as `operand * 1`, so e.g. "d0 * 2 + d0" also folds (to d0 * 3).
  Optional<int64_t> rLhsConst, rRhsConst;
  AffineExpr firstExpr, secondExpr;
  AffineConstantExpr rLhsConstExpr;
  auto lBinOpExpr = lhs.dyn_cast<AffineBinaryOpExpr>();
  if (lBinOpExpr && lBinOpExpr.getKind() == AffineExprKind::Mul &&
      (rLhsConstExpr = lBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
    rLhsConst = rLhsConstExpr.getValue();
    firstExpr = lBinOpExpr.getLHS();
  } else {
    rLhsConst = 1;
    firstExpr = lhs;
  }

  auto rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
  AffineConstantExpr rRhsConstExpr;
  if (rBinOpExpr && rBinOpExpr.getKind() == AffineExprKind::Mul &&
      (rRhsConstExpr = rBinOpExpr.getRHS().dyn_cast<AffineConstantExpr>())) {
    rRhsConst = rRhsConstExpr.getValue();
    secondExpr = rBinOpExpr.getLHS();
  } else {
    rRhsConst = 1;
    secondExpr = rhs;
  }

  // NOTE(review): both optionals are assigned on every path above, so only the
  // multiplicand comparison actually decides this fold.
  if (rLhsConst && rRhsConst && firstExpr == secondExpr)
    return getAffineBinaryOpExpr(
        AffineExprKind::Mul, firstExpr,
        getAffineConstantExpr(*rLhsConst + *rRhsConst, lhs.getContext()));

  // When doing successive additions, bring constant to the right: turn (d0 + 2)
  // + d1 into (d0 + d1) + 2.
  if (lBin && lBin.getKind() == AffineExprKind::Add) {
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
      return lBin.getLHS() + rhs + lrhs;
    }
  }

  // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
  // q may be a constant or symbolic expression. This leads to a much more
  // efficient form when 'c' is a power of two, and in general a more compact
  // and readable form.
  // Two shapes of the RHS are recognized:
  //   ((expr floordiv q) * q) * -1     (symbolic q)
  //   (expr floordiv c) * (-c)         (constant c)

  // Process '(expr floordiv c) * (-c)'.
  if (!rBinOpExpr)
    return nullptr;

  auto lrhs = rBinOpExpr.getLHS();
  auto rrhs = rBinOpExpr.getRHS();

  AffineExpr llrhs, rlrhs;

  // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
  // symbolic expression.
  auto lrhsBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
  // Check rrhsConstOpExpr = -1.
  auto rrhsConstOpExpr = rrhs.dyn_cast<AffineConstantExpr>();
  if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
      lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
    // Check llrhs = expr floordiv q.
    llrhs = lrhsBinOpExpr.getLHS();
    // Check rlrhs = q.
    rlrhs = lrhsBinOpExpr.getRHS();
    auto llrhsBinOpExpr = llrhs.dyn_cast<AffineBinaryOpExpr>();
    if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
      return nullptr;
    if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
      return lhs % rlrhs;
  }

  // Process lrhs, which is 'expr floordiv c'.
  AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
  if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
    return nullptr;

  llrhs = lrBinOpExpr.getLHS();
  rlrhs = lrBinOpExpr.getRHS();

  // The multiplier must be the negated divisor: rrhs == -c.
  if (lhs == llrhs && rlrhs == -rrhs) {
    return lhs % rlrhs;
  }
  return nullptr;
}
638
operator +(int64_t v) const639 AffineExpr AffineExpr::operator+(int64_t v) const {
640 return *this + getAffineConstantExpr(v, getContext());
641 }
operator +(AffineExpr other) const642 AffineExpr AffineExpr::operator+(AffineExpr other) const {
643 if (auto simplified = simplifyAdd(*this, other))
644 return simplified;
645
646 StorageUniquer &uniquer = getContext()->getAffineUniquer();
647 return uniquer.get<AffineBinaryOpExprStorage>(
648 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
649 }
650
/// Simplify a multiply expression. Return nullptr if it can't be simplified.
///
/// Asserts that at least one operand is symbolic or constant. The
/// canonicalization step re-enters AffineExpr::operator* with swapped
/// operands, so the later rules may assume any constant is on the RHS.
static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();

  // Fold if both operands are constants.
  if (lhsConst && rhsConst)
    return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
                                 lhs.getContext());

  assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());

  // Canonicalize the mul expression so that the constant/symbolic term is the
  // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
  // constant. (Note that a constant is trivially symbolic).
  if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
    // At least one of them has to be symbolic.
    return rhs * lhs;
  }

  // At this point, if there was a constant, it would be on the right.

  // Multiplication with a one is a noop, return the other input.
  if (rhsConst) {
    if (rhsConst.getValue() == 1)
      return lhs;
    // Multiplication with zero.
    if (rhsConst.getValue() == 0)
      return rhsConst;
  }

  // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
      return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
  }

  // When doing successive multiplication, bring constant to the right: turn (d0
  // * 2) * d1 into (d0 * d1) * 2.
  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
      return (lBin.getLHS() * rhs) * lrhs;
    }
  }

  return nullptr;
}
698
operator *(int64_t v) const699 AffineExpr AffineExpr::operator*(int64_t v) const {
700 return *this * getAffineConstantExpr(v, getContext());
701 }
operator *(AffineExpr other) const702 AffineExpr AffineExpr::operator*(AffineExpr other) const {
703 if (auto simplified = simplifyMul(*this, other))
704 return simplified;
705
706 StorageUniquer &uniquer = getContext()->getAffineUniquer();
707 return uniquer.get<AffineBinaryOpExprStorage>(
708 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
709 }
710
711 // Unary minus, delegate to operator*.
operator -() const712 AffineExpr AffineExpr::operator-() const {
713 return *this * getAffineConstantExpr(-1, getContext());
714 }
715
716 // Delegate to operator+.
operator -(int64_t v) const717 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
operator -(AffineExpr other) const718 AffineExpr AffineExpr::operator-(AffineExpr other) const {
719 return *this + (-other);
720 }
721
simplifyFloorDiv(AffineExpr lhs,AffineExpr rhs)722 static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
723 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
724 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
725
726 // mlir floordiv by zero or negative numbers is undefined and preserved as is.
727 if (!rhsConst || rhsConst.getValue() < 1)
728 return nullptr;
729
730 if (lhsConst)
731 return getAffineConstantExpr(
732 floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
733
734 // Fold floordiv of a multiply with a constant that is a multiple of the
735 // divisor. Eg: (i * 128) floordiv 64 = i * 2.
736 if (rhsConst == 1)
737 return lhs;
738
739 // Simplify (expr * const) floordiv divConst when expr is known to be a
740 // multiple of divConst.
741 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
742 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
743 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
744 // rhsConst is known to be a positive constant.
745 if (lrhs.getValue() % rhsConst.getValue() == 0)
746 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
747 }
748 }
749
750 // Simplify (expr1 + expr2) floordiv divConst when either expr1 or expr2 is
751 // known to be a multiple of divConst.
752 if (lBin && lBin.getKind() == AffineExprKind::Add) {
753 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
754 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
755 // rhsConst is known to be a positive constant.
756 if (llhsDiv % rhsConst.getValue() == 0 ||
757 lrhsDiv % rhsConst.getValue() == 0)
758 return lBin.getLHS().floorDiv(rhsConst.getValue()) +
759 lBin.getRHS().floorDiv(rhsConst.getValue());
760 }
761
762 return nullptr;
763 }
764
floorDiv(uint64_t v) const765 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
766 return floorDiv(getAffineConstantExpr(v, getContext()));
767 }
floorDiv(AffineExpr other) const768 AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
769 if (auto simplified = simplifyFloorDiv(*this, other))
770 return simplified;
771
772 StorageUniquer &uniquer = getContext()->getAffineUniquer();
773 return uniquer.get<AffineBinaryOpExprStorage>(
774 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
775 other);
776 }
777
simplifyCeilDiv(AffineExpr lhs,AffineExpr rhs)778 static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
779 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
780 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
781
782 if (!rhsConst || rhsConst.getValue() < 1)
783 return nullptr;
784
785 if (lhsConst)
786 return getAffineConstantExpr(
787 ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
788
789 // Fold ceildiv of a multiply with a constant that is a multiple of the
790 // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
791 if (rhsConst.getValue() == 1)
792 return lhs;
793
794 // Simplify (expr * const) ceildiv divConst when const is known to be a
795 // multiple of divConst.
796 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
797 if (lBin && lBin.getKind() == AffineExprKind::Mul) {
798 if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
799 // rhsConst is known to be a positive constant.
800 if (lrhs.getValue() % rhsConst.getValue() == 0)
801 return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
802 }
803 }
804
805 return nullptr;
806 }
807
ceilDiv(uint64_t v) const808 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
809 return ceilDiv(getAffineConstantExpr(v, getContext()));
810 }
ceilDiv(AffineExpr other) const811 AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
812 if (auto simplified = simplifyCeilDiv(*this, other))
813 return simplified;
814
815 StorageUniquer &uniquer = getContext()->getAffineUniquer();
816 return uniquer.get<AffineBinaryOpExprStorage>(
817 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
818 other);
819 }
820
simplifyMod(AffineExpr lhs,AffineExpr rhs)821 static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
822 auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
823 auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
824
825 // mod w.r.t zero or negative numbers is undefined and preserved as is.
826 if (!rhsConst || rhsConst.getValue() < 1)
827 return nullptr;
828
829 if (lhsConst)
830 return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
831 lhs.getContext());
832
833 // Fold modulo of an expression that is known to be a multiple of a constant
834 // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
835 // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
836 if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
837 return getAffineConstantExpr(0, lhs.getContext());
838
839 // Simplify (expr1 + expr2) mod divConst when either expr1 or expr2 is
840 // known to be a multiple of divConst.
841 auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
842 if (lBin && lBin.getKind() == AffineExprKind::Add) {
843 int64_t llhsDiv = lBin.getLHS().getLargestKnownDivisor();
844 int64_t lrhsDiv = lBin.getRHS().getLargestKnownDivisor();
845 // rhsConst is known to be a positive constant.
846 if (llhsDiv % rhsConst.getValue() == 0)
847 return lBin.getRHS() % rhsConst.getValue();
848 if (lrhsDiv % rhsConst.getValue() == 0)
849 return lBin.getLHS() % rhsConst.getValue();
850 }
851
852 // Simplify (e % a) % b to e % b when b evenly divides a
853 if (lBin && lBin.getKind() == AffineExprKind::Mod) {
854 auto intermediate = lBin.getRHS().dyn_cast<AffineConstantExpr>();
855 if (intermediate && intermediate.getValue() >= 1 &&
856 mod(intermediate.getValue(), rhsConst.getValue()) == 0) {
857 return lBin.getLHS() % rhsConst.getValue();
858 }
859 }
860
861 return nullptr;
862 }
863
operator %(uint64_t v) const864 AffineExpr AffineExpr::operator%(uint64_t v) const {
865 return *this % getAffineConstantExpr(v, getContext());
866 }
operator %(AffineExpr other) const867 AffineExpr AffineExpr::operator%(AffineExpr other) const {
868 if (auto simplified = simplifyMod(*this, other))
869 return simplified;
870
871 StorageUniquer &uniquer = getContext()->getAffineUniquer();
872 return uniquer.get<AffineBinaryOpExprStorage>(
873 /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
874 }
875
compose(AffineMap map) const876 AffineExpr AffineExpr::compose(AffineMap map) const {
877 SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
878 map.getResults().end());
879 return replaceDimsAndSymbols(dimReplacements, {});
880 }
operator <<(raw_ostream & os,AffineExpr expr)881 raw_ostream &mlir::operator<<(raw_ostream &os, AffineExpr expr) {
882 expr.print(os);
883 return os;
884 }
885
886 /// Constructs an affine expression from a flat ArrayRef. If there are local
887 /// identifiers (neither dimensional nor symbolic) that appear in the sum of
888 /// products expression, `localExprs` is expected to have the AffineExpr
889 /// for it, and is substituted into. The ArrayRef `flatExprs` is expected to be
890 /// in the format [dims, symbols, locals, constant term].
getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,unsigned numDims,unsigned numSymbols,ArrayRef<AffineExpr> localExprs,MLIRContext * context)891 AffineExpr mlir::getAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
892 unsigned numDims,
893 unsigned numSymbols,
894 ArrayRef<AffineExpr> localExprs,
895 MLIRContext *context) {
896 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
897 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
898 "unexpected number of local expressions");
899
900 auto expr = getAffineConstantExpr(0, context);
901 // Dimensions and symbols.
902 for (unsigned j = 0; j < numDims + numSymbols; j++) {
903 if (flatExprs[j] == 0)
904 continue;
905 auto id = j < numDims ? getAffineDimExpr(j, context)
906 : getAffineSymbolExpr(j - numDims, context);
907 expr = expr + id * flatExprs[j];
908 }
909
910 // Local identifiers.
911 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
912 j++) {
913 if (flatExprs[j] == 0)
914 continue;
915 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
916 expr = expr + term;
917 }
918
919 // Constant term.
920 int64_t constTerm = flatExprs[flatExprs.size() - 1];
921 if (constTerm != 0)
922 expr = expr + constTerm;
923 return expr;
924 }
925
926 /// Constructs a semi-affine expression from a flat ArrayRef. If there are
927 /// local identifiers (neither dimensional nor symbolic) that appear in the sum
928 /// of products expression, `localExprs` is expected to have the AffineExprs for
929 /// it, and is substituted into. The ArrayRef `flatExprs` is expected to be in
930 /// the format [dims, symbols, locals, constant term]. The semi-affine
931 /// expression is constructed in the sorted order of dimension and symbol
932 /// position numbers. Note: local expressions/ids are used for mod, div as well
933 /// as symbolic RHS terms for terms that are not pure affine.
getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,unsigned numDims,unsigned numSymbols,ArrayRef<AffineExpr> localExprs,MLIRContext * context)934 static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
935 unsigned numDims,
936 unsigned numSymbols,
937 ArrayRef<AffineExpr> localExprs,
938 MLIRContext *context) {
939 assert(!flatExprs.empty() && "flatExprs cannot be empty");
940
941 // Assert expected numLocals = flatExprs.size() - numDims - numSymbols - 1.
942 assert(flatExprs.size() - numDims - numSymbols - 1 == localExprs.size() &&
943 "unexpected number of local expressions");
944
945 AffineExpr expr = getAffineConstantExpr(0, context);
946
947 // We design indices as a pair which help us present the semi-affine map as
948 // sum of product where terms are sorted based on dimension or symbol
949 // position: <keyA, keyB> for expressions of the form dimension * symbol,
950 // where keyA is the position number of the dimension and keyB is the
951 // position number of the symbol. For dimensional expressions we set the index
952 // as (position number of the dimension, -1), as we want dimensional
953 // expressions to appear before symbolic and product of dimensional and
954 // symbolic expressions having the dimension with the same position number.
955 // For symbolic expression set the index as (position number of the symbol,
956 // maximum of last dimension and symbol position) number. For example, we want
957 // the expression we are constructing to look something like: d0 + d0 * s0 +
958 // s0 + d1*s1 + s1.
959
960 // Stores the affine expression corresponding to a given index.
961 DenseMap<std::pair<unsigned, signed>, AffineExpr> indexToExprMap;
962 // Stores the constant coefficient value corresponding to a given
963 // dimension, symbol or a non-pure affine expression stored in `localExprs`.
964 DenseMap<std::pair<unsigned, signed>, int64_t> coefficients;
965 // Stores the indices as defined above, and later sorted to produce
966 // the semi-affine expression in the desired form.
967 SmallVector<std::pair<unsigned, signed>, 8> indices;
968
969 // Example: expression = d0 + d0 * s0 + 2 * s0.
970 // indices = [{0,-1}, {0, 0}, {0, 1}]
971 // coefficients = [{{0, -1}, 1}, {{0, 0}, 1}, {{0, 1}, 2}]
972 // indexToExprMap = [{{0, -1}, d0}, {{0, 0}, d0 * s0}, {{0, 1}, s0}]
973
974 // Adds entries to `indexToExprMap`, `coefficients` and `indices`.
975 auto addEntry = [&](std::pair<unsigned, signed> index, int64_t coefficient,
976 AffineExpr expr) {
977 assert(!llvm::is_contained(indices, index) &&
978 "Key is already present in indices vector and overwriting will "
979 "happen in `indexToExprMap` and `coefficients`!");
980
981 indices.push_back(index);
982 coefficients.insert({index, coefficient});
983 indexToExprMap.insert({index, expr});
984 };
985
986 // Design indices for dimensional or symbolic terms, and store the indices,
987 // constant coefficient corresponding to the indices in `coefficients` map,
988 // and affine expression corresponding to indices in `indexToExprMap` map.
989
990 for (unsigned j = 0; j < numDims; ++j) {
991 if (flatExprs[j] == 0)
992 continue;
993 // For dimensional expressions we set the index as <position number of the
994 // dimension, 0>, as we want dimensional expressions to appear before
995 // symbolic ones and products of dimensional and symbolic expressions
996 // having the dimension with the same position number.
997 std::pair<unsigned, signed> indexEntry(j, -1);
998 addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
999 }
1000 for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
1001 if (flatExprs[j] == 0)
1002 continue;
1003 // For symbolic expression set the index as <position number
1004 // of the symbol, max(dimCount, symCount)> number,
1005 // as we want symbolic expressions with the same positional number to
1006 // appear after dimensional expressions having the same positional number.
1007 std::pair<unsigned, signed> indexEntry(j - numDims,
1008 std::max(numDims, numSymbols));
1009 addEntry(indexEntry, flatExprs[j],
1010 getAffineSymbolExpr(j - numDims, context));
1011 }
1012
1013 // Denotes semi-affine product, modulo or division terms, which has been added
1014 // to the `indexToExpr` map.
1015 SmallVector<bool, 4> addedToMap(flatExprs.size() - numDims - numSymbols - 1,
1016 false);
1017 unsigned lhsPos, rhsPos;
1018 // Construct indices for product terms involving dimension, symbol or constant
1019 // as lhs/rhs, and store the indices, constant coefficient corresponding to
1020 // the indices in `coefficients` map, and affine expression corresponding to
1021 // in indices in `indexToExprMap` map.
1022 for (const auto &it : llvm::enumerate(localExprs)) {
1023 AffineExpr expr = it.value();
1024 if (flatExprs[numDims + numSymbols + it.index()] == 0)
1025 continue;
1026 AffineExpr lhs = expr.cast<AffineBinaryOpExpr>().getLHS();
1027 AffineExpr rhs = expr.cast<AffineBinaryOpExpr>().getRHS();
1028 if (!((lhs.isa<AffineDimExpr>() || lhs.isa<AffineSymbolExpr>()) &&
1029 (rhs.isa<AffineDimExpr>() || rhs.isa<AffineSymbolExpr>() ||
1030 rhs.isa<AffineConstantExpr>()))) {
1031 continue;
1032 }
1033 if (rhs.isa<AffineConstantExpr>()) {
1034 // For product/modulo/division expressions, when rhs of modulo/division
1035 // expression is constant, we put 0 in place of keyB, because we want
1036 // them to appear earlier in the semi-affine expression we are
1037 // constructing. When rhs is constant, we place 0 in place of keyB.
1038 if (lhs.isa<AffineDimExpr>()) {
1039 lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1040 std::pair<unsigned, signed> indexEntry(lhsPos, -1);
1041 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1042 expr);
1043 } else {
1044 lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1045 std::pair<unsigned, signed> indexEntry(lhsPos,
1046 std::max(numDims, numSymbols));
1047 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
1048 expr);
1049 }
1050 } else if (lhs.isa<AffineDimExpr>()) {
1051 // For product/modulo/division expressions having lhs as dimension and rhs
1052 // as symbol, we order the terms in the semi-affine expression based on
1053 // the pair: <keyA, keyB> for expressions of the form dimension * symbol,
1054 // where keyA is the position number of the dimension and keyB is the
1055 // position number of the symbol.
1056 lhsPos = lhs.cast<AffineDimExpr>().getPosition();
1057 rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1058 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1059 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1060 } else {
1061 // For product/modulo/division expressions having both lhs and rhs as
1062 // symbol, we design indices as a pair: <keyA, keyB> for expressions
1063 // of the form dimension * symbol, where keyA is the position number of
1064 // the dimension and keyB is the position number of the symbol.
1065 lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
1066 rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
1067 std::pair<unsigned, signed> indexEntry(lhsPos, rhsPos);
1068 addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
1069 }
1070 addedToMap[it.index()] = true;
1071 }
1072
1073 // Constructing the simplified semi-affine sum of product/division/mod
1074 // expression from the flattened form in the desired sorted order of indices
1075 // of the various individual product/division/mod expressions.
1076 llvm::sort(indices);
1077 for (const std::pair<unsigned, unsigned> index : indices) {
1078 assert(indexToExprMap.lookup(index) &&
1079 "cannot find key in `indexToExprMap` map");
1080 expr = expr + indexToExprMap.lookup(index) * coefficients.lookup(index);
1081 }
1082
1083 // Local identifiers.
1084 for (unsigned j = numDims + numSymbols, e = flatExprs.size() - 1; j < e;
1085 j++) {
1086 // If the coefficient of the local expression is 0, continue as we need not
1087 // add it in out final expression.
1088 if (flatExprs[j] == 0 || addedToMap[j - numDims - numSymbols])
1089 continue;
1090 auto term = localExprs[j - numDims - numSymbols] * flatExprs[j];
1091 expr = expr + term;
1092 }
1093
1094 // Constant term.
1095 int64_t constTerm = flatExprs.back();
1096 if (constTerm != 0)
1097 expr = expr + constTerm;
1098 return expr;
1099 }
1100
SimpleAffineExprFlattener(unsigned numDims,unsigned numSymbols)1101 SimpleAffineExprFlattener::SimpleAffineExprFlattener(unsigned numDims,
1102 unsigned numSymbols)
1103 : numDims(numDims), numSymbols(numSymbols), numLocals(0) {
1104 operandExprStack.reserve(8);
1105 }
1106
1107 // In pure affine t = expr * c, we multiply each coefficient of lhs with c.
1108 //
1109 // In case of semi affine multiplication expressions, t = expr * symbolic_expr,
1110 // introduce a local variable p (= expr * symbolic_expr), and the affine
1111 // expression expr * symbolic_expr is added to `localExprs`.
visitMulExpr(AffineBinaryOpExpr expr)1112 void SimpleAffineExprFlattener::visitMulExpr(AffineBinaryOpExpr expr) {
1113 assert(operandExprStack.size() >= 2);
1114 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1115 operandExprStack.pop_back();
1116 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1117
1118 // Flatten semi-affine multiplication expressions by introducing a local
1119 // variable in place of the product; the affine expression
1120 // corresponding to the quantifier is added to `localExprs`.
1121 if (!expr.getRHS().isa<AffineConstantExpr>()) {
1122 MLIRContext *context = expr.getContext();
1123 AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
1124 localExprs, context);
1125 AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1126 localExprs, context);
1127 addLocalVariableSemiAffine(a * b, lhs, lhs.size());
1128 return;
1129 }
1130
1131 // Get the RHS constant.
1132 auto rhsConst = rhs[getConstantIndex()];
1133 for (unsigned i = 0, e = lhs.size(); i < e; i++) {
1134 lhs[i] *= rhsConst;
1135 }
1136 }
1137
visitAddExpr(AffineBinaryOpExpr expr)1138 void SimpleAffineExprFlattener::visitAddExpr(AffineBinaryOpExpr expr) {
1139 assert(operandExprStack.size() >= 2);
1140 const auto &rhs = operandExprStack.back();
1141 auto &lhs = operandExprStack[operandExprStack.size() - 2];
1142 assert(lhs.size() == rhs.size());
1143 // Update the LHS in place.
1144 for (unsigned i = 0, e = rhs.size(); i < e; i++) {
1145 lhs[i] += rhs[i];
1146 }
1147 // Pop off the RHS.
1148 operandExprStack.pop_back();
1149 }
1150
1151 //
1152 // t = expr mod c <=> t = expr - c*q and c*q <= expr <= c*q + c - 1
1153 //
1154 // A mod expression "expr mod c" is thus flattened by introducing a new local
1155 // variable q (= expr floordiv c), such that expr mod c is replaced with
1156 // 'expr - c * q' and c * q <= expr <= c * q + c - 1 are added to localVarCst.
1157 //
1158 // In case of semi-affine modulo expressions, t = expr mod symbolic_expr,
1159 // introduce a local variable m (= expr mod symbolic_expr), and the affine
1160 // expression expr mod symbolic_expr is added to `localExprs`.
visitModExpr(AffineBinaryOpExpr expr)1161 void SimpleAffineExprFlattener::visitModExpr(AffineBinaryOpExpr expr) {
1162 assert(operandExprStack.size() >= 2);
1163
1164 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1165 operandExprStack.pop_back();
1166 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1167 MLIRContext *context = expr.getContext();
1168
1169 // Flatten semi affine modulo expressions by introducing a local
1170 // variable in place of the modulo value, and the affine expression
1171 // corresponding to the quantifier is added to `localExprs`.
1172 if (!expr.getRHS().isa<AffineConstantExpr>()) {
1173 AffineExpr dividendExpr = getAffineExprFromFlatForm(
1174 lhs, numDims, numSymbols, localExprs, context);
1175 AffineExpr divisorExpr = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1176 localExprs, context);
1177 AffineExpr modExpr = dividendExpr % divisorExpr;
1178 addLocalVariableSemiAffine(modExpr, lhs, lhs.size());
1179 return;
1180 }
1181
1182 int64_t rhsConst = rhs[getConstantIndex()];
1183 // TODO: handle modulo by zero case when this issue is fixed
1184 // at the other places in the IR.
1185 assert(rhsConst > 0 && "RHS constant has to be positive");
1186
1187 // Check if the LHS expression is a multiple of modulo factor.
1188 unsigned i, e;
1189 for (i = 0, e = lhs.size(); i < e; i++)
1190 if (lhs[i] % rhsConst != 0)
1191 break;
1192 // If yes, modulo expression here simplifies to zero.
1193 if (i == lhs.size()) {
1194 std::fill(lhs.begin(), lhs.end(), 0);
1195 return;
1196 }
1197
1198 // Add a local variable for the quotient, i.e., expr % c is replaced by
1199 // (expr - q * c) where q = expr floordiv c. Do this while canceling out
1200 // the GCD of expr and c.
1201 SmallVector<int64_t, 8> floorDividend(lhs);
1202 uint64_t gcd = rhsConst;
1203 for (unsigned i = 0, e = lhs.size(); i < e; i++)
1204 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1205 // Simplify the numerator and the denominator.
1206 if (gcd != 1) {
1207 for (unsigned i = 0, e = floorDividend.size(); i < e; i++)
1208 floorDividend[i] = floorDividend[i] / static_cast<int64_t>(gcd);
1209 }
1210 int64_t floorDivisor = rhsConst / static_cast<int64_t>(gcd);
1211
1212 // Construct the AffineExpr form of the floordiv to store in localExprs.
1213
1214 AffineExpr dividendExpr = getAffineExprFromFlatForm(
1215 floorDividend, numDims, numSymbols, localExprs, context);
1216 AffineExpr divisorExpr = getAffineConstantExpr(floorDivisor, context);
1217 AffineExpr floorDivExpr = dividendExpr.floorDiv(divisorExpr);
1218 int loc;
1219 if ((loc = findLocalId(floorDivExpr)) == -1) {
1220 addLocalFloorDivId(floorDividend, floorDivisor, floorDivExpr);
1221 // Set result at top of stack to "lhs - rhsConst * q".
1222 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
1223 } else {
1224 // Reuse the existing local id.
1225 lhs[getLocalVarStartIndex() + loc] = -rhsConst;
1226 }
1227 }
1228
visitCeilDivExpr(AffineBinaryOpExpr expr)1229 void SimpleAffineExprFlattener::visitCeilDivExpr(AffineBinaryOpExpr expr) {
1230 visitDivExpr(expr, /*isCeil=*/true);
1231 }
visitFloorDivExpr(AffineBinaryOpExpr expr)1232 void SimpleAffineExprFlattener::visitFloorDivExpr(AffineBinaryOpExpr expr) {
1233 visitDivExpr(expr, /*isCeil=*/false);
1234 }
1235
visitDimExpr(AffineDimExpr expr)1236 void SimpleAffineExprFlattener::visitDimExpr(AffineDimExpr expr) {
1237 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1238 auto &eq = operandExprStack.back();
1239 assert(expr.getPosition() < numDims && "Inconsistent number of dims");
1240 eq[getDimStartIndex() + expr.getPosition()] = 1;
1241 }
1242
visitSymbolExpr(AffineSymbolExpr expr)1243 void SimpleAffineExprFlattener::visitSymbolExpr(AffineSymbolExpr expr) {
1244 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1245 auto &eq = operandExprStack.back();
1246 assert(expr.getPosition() < numSymbols && "inconsistent number of symbols");
1247 eq[getSymbolStartIndex() + expr.getPosition()] = 1;
1248 }
1249
visitConstantExpr(AffineConstantExpr expr)1250 void SimpleAffineExprFlattener::visitConstantExpr(AffineConstantExpr expr) {
1251 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
1252 auto &eq = operandExprStack.back();
1253 eq[getConstantIndex()] = expr.getValue();
1254 }
1255
addLocalVariableSemiAffine(AffineExpr expr,SmallVectorImpl<int64_t> & result,unsigned long resultSize)1256 void SimpleAffineExprFlattener::addLocalVariableSemiAffine(
1257 AffineExpr expr, SmallVectorImpl<int64_t> &result,
1258 unsigned long resultSize) {
1259 assert(result.size() == resultSize &&
1260 "`result` vector passed is not of correct size");
1261 int loc;
1262 if ((loc = findLocalId(expr)) == -1)
1263 addLocalIdSemiAffine(expr);
1264 std::fill(result.begin(), result.end(), 0);
1265 if (loc == -1)
1266 result[getLocalVarStartIndex() + numLocals - 1] = 1;
1267 else
1268 result[getLocalVarStartIndex() + loc] = 1;
1269 }
1270
1271 // t = expr floordiv c <=> t = q, c * q <= expr <= c * q + c - 1
1272 // A floordiv is thus flattened by introducing a new local variable q, and
1273 // replacing that expression with 'q' while adding the constraints
1274 // c * q <= expr <= c * q + c - 1 to localVarCst (done by
1275 // FlatAffineConstraints::addLocalFloorDiv).
1276 //
1277 // A ceildiv is similarly flattened:
1278 // t = expr ceildiv c <=> t = (expr + c - 1) floordiv c
1279 //
1280 // In case of semi affine division expressions, t = expr floordiv symbolic_expr
1281 // or t = expr ceildiv symbolic_expr, introduce a local variable q (= expr
1282 // floordiv/ceildiv symbolic_expr), and the affine floordiv/ceildiv is added to
1283 // `localExprs`.
visitDivExpr(AffineBinaryOpExpr expr,bool isCeil)1284 void SimpleAffineExprFlattener::visitDivExpr(AffineBinaryOpExpr expr,
1285 bool isCeil) {
1286 assert(operandExprStack.size() >= 2);
1287
1288 MLIRContext *context = expr.getContext();
1289 SmallVector<int64_t, 8> rhs = operandExprStack.back();
1290 operandExprStack.pop_back();
1291 SmallVector<int64_t, 8> &lhs = operandExprStack.back();
1292
1293 // Flatten semi affine division expressions by introducing a local
1294 // variable in place of the quotient, and the affine expression corresponding
1295 // to the quantifier is added to `localExprs`.
1296 if (!expr.getRHS().isa<AffineConstantExpr>()) {
1297 AffineExpr a = getAffineExprFromFlatForm(lhs, numDims, numSymbols,
1298 localExprs, context);
1299 AffineExpr b = getAffineExprFromFlatForm(rhs, numDims, numSymbols,
1300 localExprs, context);
1301 AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1302 addLocalVariableSemiAffine(divExpr, lhs, lhs.size());
1303 return;
1304 }
1305
1306 // This is a pure affine expr; the RHS is a positive constant.
1307 int64_t rhsConst = rhs[getConstantIndex()];
1308 // TODO: handle division by zero at the same time the issue is
1309 // fixed at other places.
1310 assert(rhsConst > 0 && "RHS constant has to be positive");
1311
1312 // Simplify the floordiv, ceildiv if possible by canceling out the greatest
1313 // common divisors of the numerator and denominator.
1314 uint64_t gcd = std::abs(rhsConst);
1315 for (unsigned i = 0, e = lhs.size(); i < e; i++)
1316 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
1317 // Simplify the numerator and the denominator.
1318 if (gcd != 1) {
1319 for (unsigned i = 0, e = lhs.size(); i < e; i++)
1320 lhs[i] = lhs[i] / static_cast<int64_t>(gcd);
1321 }
1322 int64_t divisor = rhsConst / static_cast<int64_t>(gcd);
1323 // If the divisor becomes 1, the updated LHS is the result. (The
1324 // divisor can't be negative since rhsConst is positive).
1325 if (divisor == 1)
1326 return;
1327
1328 // If the divisor cannot be simplified to one, we will have to retain
1329 // the ceil/floor expr (simplified up until here). Add an existential
1330 // quantifier to express its result, i.e., expr1 div expr2 is replaced
1331 // by a new identifier, q.
1332 AffineExpr a =
1333 getAffineExprFromFlatForm(lhs, numDims, numSymbols, localExprs, context);
1334 AffineExpr b = getAffineConstantExpr(divisor, context);
1335
1336 int loc;
1337 AffineExpr divExpr = isCeil ? a.ceilDiv(b) : a.floorDiv(b);
1338 if ((loc = findLocalId(divExpr)) == -1) {
1339 if (!isCeil) {
1340 SmallVector<int64_t, 8> dividend(lhs);
1341 addLocalFloorDivId(dividend, divisor, divExpr);
1342 } else {
1343 // lhs ceildiv c <=> (lhs + c - 1) floordiv c
1344 SmallVector<int64_t, 8> dividend(lhs);
1345 dividend.back() += divisor - 1;
1346 addLocalFloorDivId(dividend, divisor, divExpr);
1347 }
1348 }
1349 // Set the expression on stack to the local var introduced to capture the
1350 // result of the division (floor or ceil).
1351 std::fill(lhs.begin(), lhs.end(), 0);
1352 if (loc == -1)
1353 lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
1354 else
1355 lhs[getLocalVarStartIndex() + loc] = 1;
1356 }
1357
1358 // Add a local identifier (needed to flatten a mod, floordiv, ceildiv expr).
1359 // The local identifier added is always a floordiv of a pure add/mul affine
1360 // function of other identifiers, coefficients of which are specified in
1361 // dividend and with respect to a positive constant divisor. localExpr is the
1362 // simplified tree expression (AffineExpr) corresponding to the quantifier.
addLocalFloorDivId(ArrayRef<int64_t> dividend,int64_t divisor,AffineExpr localExpr)1363 void SimpleAffineExprFlattener::addLocalFloorDivId(ArrayRef<int64_t> dividend,
1364 int64_t divisor,
1365 AffineExpr localExpr) {
1366 assert(divisor > 0 && "positive constant divisor expected");
1367 for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1368 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1369 localExprs.push_back(localExpr);
1370 numLocals++;
1371 // dividend and divisor are not used here; an override of this method uses it.
1372 }
1373
addLocalIdSemiAffine(AffineExpr localExpr)1374 void SimpleAffineExprFlattener::addLocalIdSemiAffine(AffineExpr localExpr) {
1375 for (SmallVector<int64_t, 8> &subExpr : operandExprStack)
1376 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
1377 localExprs.push_back(localExpr);
1378 ++numLocals;
1379 }
1380
findLocalId(AffineExpr localExpr)1381 int SimpleAffineExprFlattener::findLocalId(AffineExpr localExpr) {
1382 SmallVectorImpl<AffineExpr>::iterator it;
1383 if ((it = llvm::find(localExprs, localExpr)) == localExprs.end())
1384 return -1;
1385 return it - localExprs.begin();
1386 }
1387
1388 /// Simplify the affine expression by flattening it and reconstructing it.
simplifyAffineExpr(AffineExpr expr,unsigned numDims,unsigned numSymbols)1389 AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
1390 unsigned numSymbols) {
1391 // Simplify semi-affine expressions separately.
1392 if (!expr.isPureAffine())
1393 expr = simplifySemiAffine(expr);
1394
1395 SimpleAffineExprFlattener flattener(numDims, numSymbols);
1396 flattener.walkPostOrder(expr);
1397 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
1398 if (!expr.isPureAffine() &&
1399 expr == getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1400 flattener.localExprs,
1401 expr.getContext()))
1402 return expr;
1403 AffineExpr simplifiedExpr =
1404 expr.isPureAffine()
1405 ? getAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1406 flattener.localExprs, expr.getContext())
1407 : getSemiAffineExprFromFlatForm(flattenedExpr, numDims, numSymbols,
1408 flattener.localExprs,
1409 expr.getContext());
1410
1411 flattener.operandExprStack.pop_back();
1412 assert(flattener.operandExprStack.empty());
1413 return simplifiedExpr;
1414 }
1415