1 //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AffineMap.h"
10 #include "AffineMapDetail.h"
11 #include "mlir/IR/BuiltinAttributes.h"
12 #include "mlir/IR/BuiltinTypes.h"
13 #include "mlir/Support/LogicalResult.h"
14 #include "mlir/Support/MathExtras.h"
15 #include "llvm/ADT/SmallBitVector.h"
16 #include "llvm/ADT/SmallSet.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/Support/raw_ostream.h"
19 
20 using namespace mlir;
21 
22 namespace {
23 
24 // AffineExprConstantFolder evaluates an affine expression using constant
25 // operands passed in 'operandConsts'. Returns an IntegerAttr attribute
26 // representing the constant value of the affine expression evaluated on
27 // constant 'operandConsts', or nullptr if it can't be folded.
28 class AffineExprConstantFolder {
29 public:
AffineExprConstantFolder(unsigned numDims,ArrayRef<Attribute> operandConsts)30   AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts)
31       : numDims(numDims), operandConsts(operandConsts) {}
32 
33   /// Attempt to constant fold the specified affine expr, or return null on
34   /// failure.
constantFold(AffineExpr expr)35   IntegerAttr constantFold(AffineExpr expr) {
36     if (auto result = constantFoldImpl(expr))
37       return IntegerAttr::get(IndexType::get(expr.getContext()), *result);
38     return nullptr;
39   }
40 
41 private:
constantFoldImpl(AffineExpr expr)42   Optional<int64_t> constantFoldImpl(AffineExpr expr) {
43     switch (expr.getKind()) {
44     case AffineExprKind::Add:
45       return constantFoldBinExpr(
46           expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; });
47     case AffineExprKind::Mul:
48       return constantFoldBinExpr(
49           expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
50     case AffineExprKind::Mod:
51       return constantFoldBinExpr(
52           expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
53     case AffineExprKind::FloorDiv:
54       return constantFoldBinExpr(
55           expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
56     case AffineExprKind::CeilDiv:
57       return constantFoldBinExpr(
58           expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
59     case AffineExprKind::Constant:
60       return expr.cast<AffineConstantExpr>().getValue();
61     case AffineExprKind::DimId:
62       if (auto attr = operandConsts[expr.cast<AffineDimExpr>().getPosition()]
63                           .dyn_cast_or_null<IntegerAttr>())
64         return attr.getInt();
65       return llvm::None;
66     case AffineExprKind::SymbolId:
67       if (auto attr = operandConsts[numDims +
68                                     expr.cast<AffineSymbolExpr>().getPosition()]
69                           .dyn_cast_or_null<IntegerAttr>())
70         return attr.getInt();
71       return llvm::None;
72     }
73     llvm_unreachable("Unknown AffineExpr");
74   }
75 
76   // TODO: Change these to operate on APInts too.
constantFoldBinExpr(AffineExpr expr,int64_t (* op)(int64_t,int64_t))77   Optional<int64_t> constantFoldBinExpr(AffineExpr expr,
78                                         int64_t (*op)(int64_t, int64_t)) {
79     auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
80     if (auto lhs = constantFoldImpl(binOpExpr.getLHS()))
81       if (auto rhs = constantFoldImpl(binOpExpr.getRHS()))
82         return op(*lhs, *rhs);
83     return llvm::None;
84   }
85 
86   // The number of dimension operands in AffineMap containing this expression.
87   unsigned numDims;
88   // The constant valued operands used to evaluate this AffineExpr.
89   ArrayRef<Attribute> operandConsts;
90 };
91 
92 } // namespace
93 
94 /// Returns a single constant result affine map.
getConstantMap(int64_t val,MLIRContext * context)95 AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
96   return get(/*dimCount=*/0, /*symbolCount=*/0,
97              {getAffineConstantExpr(val, context)});
98 }
99 
100 /// Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most
101 /// minor dimensions.
getMinorIdentityMap(unsigned dims,unsigned results,MLIRContext * context)102 AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results,
103                                          MLIRContext *context) {
104   assert(dims >= results && "Dimension mismatch");
105   auto id = AffineMap::getMultiDimIdentityMap(dims, context);
106   return AffineMap::get(dims, 0, id.getResults().take_back(results), context);
107 }
108 
isMinorIdentity() const109 bool AffineMap::isMinorIdentity() const {
110   return getNumDims() >= getNumResults() &&
111          *this ==
112              getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
113 }
114 
115 /// Returns true if this affine map is a minor identity up to broadcasted
116 /// dimensions which are indicated by value 0 in the result.
isMinorIdentityWithBroadcasting(SmallVectorImpl<unsigned> * broadcastedDims) const117 bool AffineMap::isMinorIdentityWithBroadcasting(
118     SmallVectorImpl<unsigned> *broadcastedDims) const {
119   if (broadcastedDims)
120     broadcastedDims->clear();
121   if (getNumDims() < getNumResults())
122     return false;
123   unsigned suffixStart = getNumDims() - getNumResults();
124   for (const auto &idxAndExpr : llvm::enumerate(getResults())) {
125     unsigned resIdx = idxAndExpr.index();
126     AffineExpr expr = idxAndExpr.value();
127     if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
128       // Each result may be either a constant 0 (broadcasted dimension).
129       if (constExpr.getValue() != 0)
130         return false;
131       if (broadcastedDims)
132         broadcastedDims->push_back(resIdx);
133     } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
134       // Or it may be the input dimension corresponding to this result position.
135       if (dimExpr.getPosition() != suffixStart + resIdx)
136         return false;
137     } else {
138       return false;
139     }
140   }
141   return true;
142 }
143 
144 /// Return true if this affine map can be converted to a minor identity with
145 /// broadcast by doing a permute. Return a permutation (there may be
146 /// several) to apply to get to a minor identity with broadcasts.
147 /// Ex:
148 ///  * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with
149 ///  perm = [1, 0] and broadcast d2
150 ///  * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by
151 ///  permutation + broadcast
152 ///  * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3)
153 ///  with perm = [1, 0, 2] and broadcast d2
154 ///  * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra
155 ///  leading broadcat dimensions. The map returned would be (0, 0, d0, d1) with
156 ///  perm = [3, 0, 1, 2]
isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl<unsigned> & permutedDims) const157 bool AffineMap::isPermutationOfMinorIdentityWithBroadcasting(
158     SmallVectorImpl<unsigned> &permutedDims) const {
159   unsigned projectionStart =
160       getNumResults() < getNumInputs() ? getNumInputs() - getNumResults() : 0;
161   permutedDims.clear();
162   SmallVector<unsigned> broadcastDims;
163   permutedDims.resize(getNumResults(), 0);
164   // If there are more results than input dimensions we want the new map to
165   // start with broadcast dimensions in order to be a minor identity with
166   // broadcasting.
167   unsigned leadingBroadcast =
168       getNumResults() > getNumInputs() ? getNumResults() - getNumInputs() : 0;
169   llvm::SmallBitVector dimFound(std::max(getNumInputs(), getNumResults()),
170                                 false);
171   for (const auto &idxAndExpr : llvm::enumerate(getResults())) {
172     unsigned resIdx = idxAndExpr.index();
173     AffineExpr expr = idxAndExpr.value();
174     // Each result may be either a constant 0 (broadcast dimension) or a
175     // dimension.
176     if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
177       if (constExpr.getValue() != 0)
178         return false;
179       broadcastDims.push_back(resIdx);
180     } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
181       if (dimExpr.getPosition() < projectionStart)
182         return false;
183       unsigned newPosition =
184           dimExpr.getPosition() - projectionStart + leadingBroadcast;
185       permutedDims[resIdx] = newPosition;
186       dimFound[newPosition] = true;
187     } else {
188       return false;
189     }
190   }
191   // Find a permuation for the broadcast dimension. Since they are broadcasted
192   // any valid permutation is acceptable. We just permute the dim into a slot
193   // without an existing dimension.
194   unsigned pos = 0;
195   for (auto dim : broadcastDims) {
196     while (pos < dimFound.size() && dimFound[pos]) {
197       pos++;
198     }
199     permutedDims[dim] = pos++;
200   }
201   return true;
202 }
203 
204 /// Returns an AffineMap representing a permutation.
getPermutationMap(ArrayRef<unsigned> permutation,MLIRContext * context)205 AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
206                                        MLIRContext *context) {
207   assert(!permutation.empty() &&
208          "Cannot create permutation map from empty permutation vector");
209   SmallVector<AffineExpr, 4> affExprs;
210   for (auto index : permutation)
211     affExprs.push_back(getAffineDimExpr(index, context));
212   const auto *m = std::max_element(permutation.begin(), permutation.end());
213   auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context);
214   assert(permutationMap.isPermutation() && "Invalid permutation vector");
215   return permutationMap;
216 }
217 
218 template <typename AffineExprContainer>
219 static SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<AffineExprContainer> exprsList)220 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
221   assert(!exprsList.empty());
222   assert(!exprsList[0].empty());
223   auto context = exprsList[0][0].getContext();
224   int64_t maxDim = -1, maxSym = -1;
225   getMaxDimAndSymbol(exprsList, maxDim, maxSym);
226   SmallVector<AffineMap, 4> maps;
227   maps.reserve(exprsList.size());
228   for (const auto &exprs : exprsList)
229     maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1,
230                                   /*symbolCount=*/maxSym + 1, exprs, context));
231   return maps;
232 }
233 
234 SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList)235 AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) {
236   return ::inferFromExprList(exprsList);
237 }
238 
239 SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<SmallVector<AffineExpr,4>> exprsList)240 AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) {
241   return ::inferFromExprList(exprsList);
242 }
243 
getMultiDimIdentityMap(unsigned numDims,MLIRContext * context)244 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
245                                             MLIRContext *context) {
246   SmallVector<AffineExpr, 4> dimExprs;
247   dimExprs.reserve(numDims);
248   for (unsigned i = 0; i < numDims; ++i)
249     dimExprs.push_back(mlir::getAffineDimExpr(i, context));
250   return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, context);
251 }
252 
getContext() const253 MLIRContext *AffineMap::getContext() const { return map->context; }
254 
isIdentity() const255 bool AffineMap::isIdentity() const {
256   if (getNumDims() != getNumResults())
257     return false;
258   ArrayRef<AffineExpr> results = getResults();
259   for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
260     auto expr = results[i].dyn_cast<AffineDimExpr>();
261     if (!expr || expr.getPosition() != i)
262       return false;
263   }
264   return true;
265 }
266 
isEmpty() const267 bool AffineMap::isEmpty() const {
268   return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0;
269 }
270 
isSingleConstant() const271 bool AffineMap::isSingleConstant() const {
272   return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
273 }
274 
isConstant() const275 bool AffineMap::isConstant() const {
276   return llvm::all_of(getResults(), [](AffineExpr expr) {
277     return expr.isa<AffineConstantExpr>();
278   });
279 }
280 
getSingleConstantResult() const281 int64_t AffineMap::getSingleConstantResult() const {
282   assert(isSingleConstant() && "map must have a single constant result");
283   return getResult(0).cast<AffineConstantExpr>().getValue();
284 }
285 
getConstantResults() const286 SmallVector<int64_t> AffineMap::getConstantResults() const {
287   assert(isConstant() && "map must have only constant results");
288   SmallVector<int64_t> result;
289   for (auto expr : getResults())
290     result.emplace_back(expr.cast<AffineConstantExpr>().getValue());
291   return result;
292 }
293 
getNumDims() const294 unsigned AffineMap::getNumDims() const {
295   assert(map && "uninitialized map storage");
296   return map->numDims;
297 }
getNumSymbols() const298 unsigned AffineMap::getNumSymbols() const {
299   assert(map && "uninitialized map storage");
300   return map->numSymbols;
301 }
getNumResults() const302 unsigned AffineMap::getNumResults() const { return getResults().size(); }
getNumInputs() const303 unsigned AffineMap::getNumInputs() const {
304   assert(map && "uninitialized map storage");
305   return map->numDims + map->numSymbols;
306 }
getResults() const307 ArrayRef<AffineExpr> AffineMap::getResults() const {
308   assert(map && "uninitialized map storage");
309   return map->results();
310 }
getResult(unsigned idx) const311 AffineExpr AffineMap::getResult(unsigned idx) const {
312   return getResults()[idx];
313 }
314 
getDimPosition(unsigned idx) const315 unsigned AffineMap::getDimPosition(unsigned idx) const {
316   return getResult(idx).cast<AffineDimExpr>().getPosition();
317 }
318 
getPermutedPosition(unsigned input) const319 unsigned AffineMap::getPermutedPosition(unsigned input) const {
320   assert(isPermutation() && "invalid permutation request");
321   for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++)
322     if (getDimPosition(i) == input)
323       return i;
324   llvm_unreachable("incorrect permutation request");
325 }
326 
327 /// Folds the results of the application of an affine map on the provided
328 /// operands to a constant if possible. Returns false if the folding happens,
329 /// true otherwise.
330 LogicalResult
constantFold(ArrayRef<Attribute> operandConstants,SmallVectorImpl<Attribute> & results) const331 AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
332                         SmallVectorImpl<Attribute> &results) const {
333   // Attempt partial folding.
334   SmallVector<int64_t, 2> integers;
335   partialConstantFold(operandConstants, &integers);
336 
337   // If all expressions folded to a constant, populate results with attributes
338   // containing those constants.
339   if (integers.empty())
340     return failure();
341 
342   auto range = llvm::map_range(integers, [this](int64_t i) {
343     return IntegerAttr::get(IndexType::get(getContext()), i);
344   });
345   results.append(range.begin(), range.end());
346   return success();
347 }
348 
349 AffineMap
partialConstantFold(ArrayRef<Attribute> operandConstants,SmallVectorImpl<int64_t> * results) const350 AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
351                                SmallVectorImpl<int64_t> *results) const {
352   assert(getNumInputs() == operandConstants.size());
353 
354   // Fold each of the result expressions.
355   AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
356   SmallVector<AffineExpr, 4> exprs;
357   exprs.reserve(getNumResults());
358 
359   for (auto expr : getResults()) {
360     auto folded = exprFolder.constantFold(expr);
361     // If did not fold to a constant, keep the original expression, and clear
362     // the integer results vector.
363     if (folded) {
364       exprs.push_back(
365           getAffineConstantExpr(folded.getInt(), folded.getContext()));
366       if (results)
367         results->push_back(folded.getInt());
368     } else {
369       exprs.push_back(expr);
370       if (results) {
371         results->clear();
372         results = nullptr;
373       }
374     }
375   }
376 
377   return get(getNumDims(), getNumSymbols(), exprs, getContext());
378 }
379 
380 /// Walk all of the AffineExpr's in this mapping. Each node in an expression
381 /// tree is visited in postorder.
walkExprs(llvm::function_ref<void (AffineExpr)> callback) const382 void AffineMap::walkExprs(llvm::function_ref<void(AffineExpr)> callback) const {
383   for (auto expr : getResults())
384     expr.walk(callback);
385 }
386 
387 /// This method substitutes any uses of dimensions and symbols (e.g.
388 /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
389 /// expression mapping.  Because this can be used to eliminate dims and
390 /// symbols, the client needs to specify the number of dims and symbols in
391 /// the result.  The returned map always has the same number of results.
replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,ArrayRef<AffineExpr> symReplacements,unsigned numResultDims,unsigned numResultSyms) const392 AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
393                                            ArrayRef<AffineExpr> symReplacements,
394                                            unsigned numResultDims,
395                                            unsigned numResultSyms) const {
396   SmallVector<AffineExpr, 8> results;
397   results.reserve(getNumResults());
398   for (auto expr : getResults())
399     results.push_back(
400         expr.replaceDimsAndSymbols(dimReplacements, symReplacements));
401   return get(numResultDims, numResultSyms, results, getContext());
402 }
403 
404 /// Sparse replace method. Apply AffineExpr::replace(`expr`, `replacement`) to
405 /// each of the results and return a new AffineMap with the new results and
406 /// with the specified number of dims and symbols.
replace(AffineExpr expr,AffineExpr replacement,unsigned numResultDims,unsigned numResultSyms) const407 AffineMap AffineMap::replace(AffineExpr expr, AffineExpr replacement,
408                              unsigned numResultDims,
409                              unsigned numResultSyms) const {
410   SmallVector<AffineExpr, 4> newResults;
411   newResults.reserve(getNumResults());
412   for (AffineExpr e : getResults())
413     newResults.push_back(e.replace(expr, replacement));
414   return AffineMap::get(numResultDims, numResultSyms, newResults, getContext());
415 }
416 
417 /// Sparse replace method. Apply AffineExpr::replace(`map`) to each of the
418 /// results and return a new AffineMap with the new results and with the
419 /// specified number of dims and symbols.
replace(const DenseMap<AffineExpr,AffineExpr> & map,unsigned numResultDims,unsigned numResultSyms) const420 AffineMap AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map,
421                              unsigned numResultDims,
422                              unsigned numResultSyms) const {
423   SmallVector<AffineExpr, 4> newResults;
424   newResults.reserve(getNumResults());
425   for (AffineExpr e : getResults())
426     newResults.push_back(e.replace(map));
427   return AffineMap::get(numResultDims, numResultSyms, newResults, getContext());
428 }
429 
430 AffineMap
replace(const DenseMap<AffineExpr,AffineExpr> & map) const431 AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
432   SmallVector<AffineExpr, 4> newResults;
433   newResults.reserve(getNumResults());
434   for (AffineExpr e : getResults())
435     newResults.push_back(e.replace(map));
436   return AffineMap::inferFromExprList(newResults).front();
437 }
438 
compose(AffineMap map) const439 AffineMap AffineMap::compose(AffineMap map) const {
440   assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
441   // Prepare `map` by concatenating the symbols and rewriting its exprs.
442   unsigned numDims = map.getNumDims();
443   unsigned numSymbolsThisMap = getNumSymbols();
444   unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols();
445   SmallVector<AffineExpr, 8> newDims(numDims);
446   for (unsigned idx = 0; idx < numDims; ++idx) {
447     newDims[idx] = getAffineDimExpr(idx, getContext());
448   }
449   SmallVector<AffineExpr, 8> newSymbols(numSymbols - numSymbolsThisMap);
450   for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) {
451     newSymbols[idx - numSymbolsThisMap] =
452         getAffineSymbolExpr(idx, getContext());
453   }
454   auto newMap =
455       map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols);
456   SmallVector<AffineExpr, 8> exprs;
457   exprs.reserve(getResults().size());
458   for (auto expr : getResults())
459     exprs.push_back(expr.compose(newMap));
460   return AffineMap::get(numDims, numSymbols, exprs, map.getContext());
461 }
462 
compose(ArrayRef<int64_t> values) const463 SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const {
464   assert(getNumSymbols() == 0 && "Expected symbol-less map");
465   SmallVector<AffineExpr, 4> exprs;
466   exprs.reserve(values.size());
467   MLIRContext *ctx = getContext();
468   for (auto v : values)
469     exprs.push_back(getAffineConstantExpr(v, ctx));
470   auto resMap = compose(AffineMap::get(0, 0, exprs, ctx));
471   SmallVector<int64_t, 4> res;
472   res.reserve(resMap.getNumResults());
473   for (auto e : resMap.getResults())
474     res.push_back(e.cast<AffineConstantExpr>().getValue());
475   return res;
476 }
477 
isProjectedPermutation(bool allowZeroInResults) const478 bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
479   if (getNumSymbols() > 0)
480     return false;
481 
482   // Having more results than inputs means that results have duplicated dims or
483   // zeros that can't be mapped to input dims.
484   if (getNumResults() > getNumInputs())
485     return false;
486 
487   SmallVector<bool, 8> seen(getNumInputs(), false);
488   // A projected permutation can have, at most, only one instance of each input
489   // dimension in the result expressions. Zeros are allowed as long as the
490   // number of result expressions is lower or equal than the number of input
491   // expressions.
492   for (auto expr : getResults()) {
493     if (auto dim = expr.dyn_cast<AffineDimExpr>()) {
494       if (seen[dim.getPosition()])
495         return false;
496       seen[dim.getPosition()] = true;
497     } else {
498       auto constExpr = expr.dyn_cast<AffineConstantExpr>();
499       if (!allowZeroInResults || !constExpr || constExpr.getValue() != 0)
500         return false;
501     }
502   }
503 
504   // Results are either dims or zeros and zeros can be mapped to input dims.
505   return true;
506 }
507 
isPermutation() const508 bool AffineMap::isPermutation() const {
509   if (getNumDims() != getNumResults())
510     return false;
511   return isProjectedPermutation();
512 }
513 
getSubMap(ArrayRef<unsigned> resultPos) const514 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) const {
515   SmallVector<AffineExpr, 4> exprs;
516   exprs.reserve(resultPos.size());
517   for (auto idx : resultPos)
518     exprs.push_back(getResult(idx));
519   return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
520 }
521 
getSliceMap(unsigned start,unsigned length) const522 AffineMap AffineMap::getSliceMap(unsigned start, unsigned length) const {
523   return AffineMap::get(getNumDims(), getNumSymbols(),
524                         getResults().slice(start, length), getContext());
525 }
526 
getMajorSubMap(unsigned numResults) const527 AffineMap AffineMap::getMajorSubMap(unsigned numResults) const {
528   if (numResults == 0)
529     return AffineMap();
530   if (numResults > getNumResults())
531     return *this;
532   return getSliceMap(0, numResults);
533 }
534 
getMinorSubMap(unsigned numResults) const535 AffineMap AffineMap::getMinorSubMap(unsigned numResults) const {
536   if (numResults == 0)
537     return AffineMap();
538   if (numResults > getNumResults())
539     return *this;
540   return getSliceMap(getNumResults() - numResults, numResults);
541 }
542 
compressDims(AffineMap map,const llvm::SmallBitVector & unusedDims)543 AffineMap mlir::compressDims(AffineMap map,
544                              const llvm::SmallBitVector &unusedDims) {
545   unsigned numDims = 0;
546   SmallVector<AffineExpr> dimReplacements;
547   dimReplacements.reserve(map.getNumDims());
548   MLIRContext *context = map.getContext();
549   for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) {
550     if (unusedDims.test(dim))
551       dimReplacements.push_back(getAffineConstantExpr(0, context));
552     else
553       dimReplacements.push_back(getAffineDimExpr(numDims++, context));
554   }
555   SmallVector<AffineExpr> resultExprs;
556   resultExprs.reserve(map.getNumResults());
557   for (auto e : map.getResults())
558     resultExprs.push_back(e.replaceDims(dimReplacements));
559   return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context);
560 }
561 
compressUnusedDims(AffineMap map)562 AffineMap mlir::compressUnusedDims(AffineMap map) {
563   return compressDims(map, getUnusedDimsBitVector({map}));
564 }
565 
566 static SmallVector<AffineMap>
compressUnusedImpl(ArrayRef<AffineMap> maps,llvm::function_ref<AffineMap (AffineMap)> compressionFun)567 compressUnusedImpl(ArrayRef<AffineMap> maps,
568                    llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
569   if (maps.empty())
570     return SmallVector<AffineMap>();
571   SmallVector<AffineExpr> allExprs;
572   allExprs.reserve(maps.size() * maps.front().getNumResults());
573   unsigned numDims = maps.front().getNumDims(),
574            numSymbols = maps.front().getNumSymbols();
575   for (auto m : maps) {
576     assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() &&
577            "expected maps with same num dims and symbols");
578     llvm::append_range(allExprs, m.getResults());
579   }
580   AffineMap unifiedMap = compressionFun(
581       AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext()));
582   unsigned unifiedNumDims = unifiedMap.getNumDims(),
583            unifiedNumSymbols = unifiedMap.getNumSymbols();
584   ArrayRef<AffineExpr> unifiedResults = unifiedMap.getResults();
585   SmallVector<AffineMap> res;
586   res.reserve(maps.size());
587   for (auto m : maps) {
588     res.push_back(AffineMap::get(unifiedNumDims, unifiedNumSymbols,
589                                  unifiedResults.take_front(m.getNumResults()),
590                                  m.getContext()));
591     unifiedResults = unifiedResults.drop_front(m.getNumResults());
592   }
593   return res;
594 }
595 
compressUnusedDims(ArrayRef<AffineMap> maps)596 SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) {
597   return compressUnusedImpl(maps,
598                             [](AffineMap m) { return compressUnusedDims(m); });
599 }
600 
compressSymbols(AffineMap map,const llvm::SmallBitVector & unusedSymbols)601 AffineMap mlir::compressSymbols(AffineMap map,
602                                 const llvm::SmallBitVector &unusedSymbols) {
603   unsigned numSymbols = 0;
604   SmallVector<AffineExpr> symReplacements;
605   symReplacements.reserve(map.getNumSymbols());
606   MLIRContext *context = map.getContext();
607   for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) {
608     if (unusedSymbols.test(sym))
609       symReplacements.push_back(getAffineConstantExpr(0, context));
610     else
611       symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context));
612   }
613   SmallVector<AffineExpr> resultExprs;
614   resultExprs.reserve(map.getNumResults());
615   for (auto e : map.getResults())
616     resultExprs.push_back(e.replaceSymbols(symReplacements));
617   return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context);
618 }
619 
compressUnusedSymbols(AffineMap map)620 AffineMap mlir::compressUnusedSymbols(AffineMap map) {
621   llvm::SmallBitVector unusedSymbols(map.getNumSymbols(), true);
622   map.walkExprs([&](AffineExpr expr) {
623     if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
624       unusedSymbols.reset(symExpr.getPosition());
625   });
626   return compressSymbols(map, unusedSymbols);
627 }
628 
compressUnusedSymbols(ArrayRef<AffineMap> maps)629 SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) {
630   return compressUnusedImpl(
631       maps, [](AffineMap m) { return compressUnusedSymbols(m); });
632 }
633 
simplifyAffineMap(AffineMap map)634 AffineMap mlir::simplifyAffineMap(AffineMap map) {
635   SmallVector<AffineExpr, 8> exprs;
636   for (auto e : map.getResults()) {
637     exprs.push_back(
638         simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols()));
639   }
640   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs,
641                         map.getContext());
642 }
643 
removeDuplicateExprs(AffineMap map)644 AffineMap mlir::removeDuplicateExprs(AffineMap map) {
645   auto results = map.getResults();
646   SmallVector<AffineExpr, 4> uniqueExprs(results.begin(), results.end());
647   uniqueExprs.erase(std::unique(uniqueExprs.begin(), uniqueExprs.end()),
648                     uniqueExprs.end());
649   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), uniqueExprs,
650                         map.getContext());
651 }
652 
inversePermutation(AffineMap map)653 AffineMap mlir::inversePermutation(AffineMap map) {
654   if (map.isEmpty())
655     return map;
656   assert(map.getNumSymbols() == 0 && "expected map without symbols");
657   SmallVector<AffineExpr, 4> exprs(map.getNumDims());
658   for (const auto &en : llvm::enumerate(map.getResults())) {
659     auto expr = en.value();
660     // Skip non-permutations.
661     if (auto d = expr.dyn_cast<AffineDimExpr>()) {
662       if (exprs[d.getPosition()])
663         continue;
664       exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
665     }
666   }
667   SmallVector<AffineExpr, 4> seenExprs;
668   seenExprs.reserve(map.getNumDims());
669   for (auto expr : exprs)
670     if (expr)
671       seenExprs.push_back(expr);
672   if (seenExprs.size() != map.getNumInputs())
673     return AffineMap();
674   return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
675 }
676 
inverseAndBroadcastProjectedPermutation(AffineMap map)677 AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) {
678   assert(map.isProjectedPermutation(/*allowZeroInResults=*/true));
679   MLIRContext *context = map.getContext();
680   AffineExpr zero = mlir::getAffineConstantExpr(0, context);
681   // Start with all the results as 0.
682   SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero);
683   for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
684     // Skip zeros from input map. 'exprs' is already initialized to zero.
685     if (auto constExpr = map.getResult(i).dyn_cast<AffineConstantExpr>()) {
686       assert(constExpr.getValue() == 0 &&
687              "Unexpected constant in projected permutation");
688       (void)constExpr;
689       continue;
690     }
691 
692     // Reverse each dimension existing in the original map result.
693     exprs[map.getDimPosition(i)] = getAffineDimExpr(i, context);
694   }
695   return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context);
696 }
697 
concatAffineMaps(ArrayRef<AffineMap> maps)698 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
699   unsigned numResults = 0, numDims = 0, numSymbols = 0;
700   for (auto m : maps)
701     numResults += m.getNumResults();
702   SmallVector<AffineExpr, 8> results;
703   results.reserve(numResults);
704   for (auto m : maps) {
705     for (auto res : m.getResults())
706       results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols));
707 
708     numSymbols += m.getNumSymbols();
709     numDims = std::max(m.getNumDims(), numDims);
710   }
711   return AffineMap::get(numDims, numSymbols, results,
712                         maps.front().getContext());
713 }
714 
getProjectedMap(AffineMap map,const llvm::SmallBitVector & unusedDims)715 AffineMap mlir::getProjectedMap(AffineMap map,
716                                 const llvm::SmallBitVector &unusedDims) {
717   return compressUnusedSymbols(compressDims(map, unusedDims));
718 }
719 
getUnusedDimsBitVector(ArrayRef<AffineMap> maps)720 llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
721   unsigned numDims = maps[0].getNumDims();
722   llvm::SmallBitVector numDimsBitVector(numDims, true);
723   for (const auto &m : maps) {
724     for (unsigned i = 0; i < numDims; ++i) {
725       if (m.isFunctionOfDim(i))
726         numDimsBitVector.reset(i);
727     }
728   }
729   return numDimsBitVector;
730 }
731 
732 //===----------------------------------------------------------------------===//
733 // MutableAffineMap.
734 //===----------------------------------------------------------------------===//
735 
/// Constructs a mutable copy of `map`: copies its result expressions, its
/// dimension and symbol counts, and its context.
MutableAffineMap::MutableAffineMap(AffineMap map)
    : results(map.getResults().begin(), map.getResults().end()),
      numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
      context(map.getContext()) {}
740 
reset(AffineMap map)741 void MutableAffineMap::reset(AffineMap map) {
742   results.clear();
743   numDims = map.getNumDims();
744   numSymbols = map.getNumSymbols();
745   context = map.getContext();
746   llvm::append_range(results, map.getResults());
747 }
748 
isMultipleOf(unsigned idx,int64_t factor) const749 bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
750   if (results[idx].isMultipleOf(factor))
751     return true;
752 
753   // TODO: use simplifyAffineExpr and FlatAffineConstraints to
754   // complete this (for a more powerful analysis).
755   return false;
756 }
757 
758 // Simplifies the result affine expressions of this map. The expressions have to
759 // be pure for the simplification implemented.
simplify()760 void MutableAffineMap::simplify() {
761   // Simplify each of the results if possible.
762   // TODO: functional-style map
763   for (unsigned i = 0, e = getNumResults(); i < e; i++) {
764     results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols);
765   }
766 }
767 
getAffineMap() const768 AffineMap MutableAffineMap::getAffineMap() const {
769   return AffineMap::get(numDims, numSymbols, results, context);
770 }
771