1 //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/AffineMap.h"
10 #include "AffineMapDetail.h"
11 #include "mlir/IR/Attributes.h"
12 #include "mlir/IR/StandardTypes.h"
13 #include "mlir/Support/Functional.h"
14 #include "mlir/Support/LogicalResult.h"
15 #include "mlir/Support/MathExtras.h"
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/Support/raw_ostream.h"
18 
19 using namespace mlir;
20 
21 namespace {
22 
23 // AffineExprConstantFolder evaluates an affine expression using constant
24 // operands passed in 'operandConsts'. Returns an IntegerAttr attribute
25 // representing the constant value of the affine expression evaluated on
26 // constant 'operandConsts', or nullptr if it can't be folded.
27 class AffineExprConstantFolder {
28 public:
29   AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts)
30       : numDims(numDims), operandConsts(operandConsts) {}
31 
32   /// Attempt to constant fold the specified affine expr, or return null on
33   /// failure.
34   IntegerAttr constantFold(AffineExpr expr) {
35     if (auto result = constantFoldImpl(expr))
36       return IntegerAttr::get(IndexType::get(expr.getContext()), *result);
37     return nullptr;
38   }
39 
40 private:
41   Optional<int64_t> constantFoldImpl(AffineExpr expr) {
42     switch (expr.getKind()) {
43     case AffineExprKind::Add:
44       return constantFoldBinExpr(
45           expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; });
46     case AffineExprKind::Mul:
47       return constantFoldBinExpr(
48           expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
49     case AffineExprKind::Mod:
50       return constantFoldBinExpr(
51           expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
52     case AffineExprKind::FloorDiv:
53       return constantFoldBinExpr(
54           expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
55     case AffineExprKind::CeilDiv:
56       return constantFoldBinExpr(
57           expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
58     case AffineExprKind::Constant:
59       return expr.cast<AffineConstantExpr>().getValue();
60     case AffineExprKind::DimId:
61       if (auto attr = operandConsts[expr.cast<AffineDimExpr>().getPosition()]
62                           .dyn_cast_or_null<IntegerAttr>())
63         return attr.getInt();
64       return llvm::None;
65     case AffineExprKind::SymbolId:
66       if (auto attr = operandConsts[numDims +
67                                     expr.cast<AffineSymbolExpr>().getPosition()]
68                           .dyn_cast_or_null<IntegerAttr>())
69         return attr.getInt();
70       return llvm::None;
71     }
72     llvm_unreachable("Unknown AffineExpr");
73   }
74 
75   // TODO: Change these to operate on APInts too.
76   Optional<int64_t> constantFoldBinExpr(AffineExpr expr,
77                                         int64_t (*op)(int64_t, int64_t)) {
78     auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
79     if (auto lhs = constantFoldImpl(binOpExpr.getLHS()))
80       if (auto rhs = constantFoldImpl(binOpExpr.getRHS()))
81         return op(*lhs, *rhs);
82     return llvm::None;
83   }
84 
85   // The number of dimension operands in AffineMap containing this expression.
86   unsigned numDims;
87   // The constant valued operands used to evaluate this AffineExpr.
88   ArrayRef<Attribute> operandConsts;
89 };
90 
91 } // end anonymous namespace
92 
93 /// Returns a single constant result affine map.
94 AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
95   return get(/*dimCount=*/0, /*symbolCount=*/0,
96              {getAffineConstantExpr(val, context)});
97 }
98 
99 /// Returns an AffineMap representing a permutation.
100 AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
101                                        MLIRContext *context) {
102   assert(!permutation.empty() &&
103          "Cannot create permutation map from empty permutation vector");
104   SmallVector<AffineExpr, 4> affExprs;
105   for (auto index : permutation)
106     affExprs.push_back(getAffineDimExpr(index, context));
107   auto m = std::max_element(permutation.begin(), permutation.end());
108   auto permutationMap = AffineMap::get(*m + 1, 0, affExprs);
109   assert(permutationMap.isPermutation() && "Invalid permutation vector");
110   return permutationMap;
111 }
112 
113 template <typename AffineExprContainer>
114 static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
115                                int64_t &maxDim, int64_t &maxSym) {
116   for (const auto &exprs : exprsList) {
117     for (auto expr : exprs) {
118       expr.walk([&maxDim, &maxSym](AffineExpr e) {
119         if (auto d = e.dyn_cast<AffineDimExpr>())
120           maxDim = std::max(maxDim, static_cast<int64_t>(d.getPosition()));
121         if (auto s = e.dyn_cast<AffineSymbolExpr>())
122           maxSym = std::max(maxSym, static_cast<int64_t>(s.getPosition()));
123       });
124     }
125   }
126 }
127 
128 template <typename AffineExprContainer>
129 static SmallVector<AffineMap, 4>
130 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
131   int64_t maxDim = -1, maxSym = -1;
132   getMaxDimAndSymbol(exprsList, maxDim, maxSym);
133   SmallVector<AffineMap, 4> maps;
134   maps.reserve(exprsList.size());
135   for (const auto &exprs : exprsList)
136     maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1,
137                                   /*symbolCount=*/maxSym + 1, exprs));
138   return maps;
139 }
140 
141 SmallVector<AffineMap, 4>
142 AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) {
143   return ::inferFromExprList(exprsList);
144 }
145 
146 SmallVector<AffineMap, 4>
147 AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) {
148   return ::inferFromExprList(exprsList);
149 }
150 
151 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
152                                             MLIRContext *context) {
153   SmallVector<AffineExpr, 4> dimExprs;
154   dimExprs.reserve(numDims);
155   for (unsigned i = 0; i < numDims; ++i)
156     dimExprs.push_back(mlir::getAffineDimExpr(i, context));
157   return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs);
158 }
159 
160 MLIRContext *AffineMap::getContext() const { return map->context; }
161 
162 bool AffineMap::isIdentity() const {
163   if (getNumDims() != getNumResults())
164     return false;
165   ArrayRef<AffineExpr> results = getResults();
166   for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
167     auto expr = results[i].dyn_cast<AffineDimExpr>();
168     if (!expr || expr.getPosition() != i)
169       return false;
170   }
171   return true;
172 }
173 
174 bool AffineMap::isEmpty() const {
175   return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0;
176 }
177 
178 bool AffineMap::isSingleConstant() const {
179   return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
180 }
181 
182 int64_t AffineMap::getSingleConstantResult() const {
183   assert(isSingleConstant() && "map must have a single constant result");
184   return getResult(0).cast<AffineConstantExpr>().getValue();
185 }
186 
187 unsigned AffineMap::getNumDims() const {
188   assert(map && "uninitialized map storage");
189   return map->numDims;
190 }
191 unsigned AffineMap::getNumSymbols() const {
192   assert(map && "uninitialized map storage");
193   return map->numSymbols;
194 }
195 unsigned AffineMap::getNumResults() const {
196   assert(map && "uninitialized map storage");
197   return map->results.size();
198 }
199 unsigned AffineMap::getNumInputs() const {
200   assert(map && "uninitialized map storage");
201   return map->numDims + map->numSymbols;
202 }
203 
204 ArrayRef<AffineExpr> AffineMap::getResults() const {
205   assert(map && "uninitialized map storage");
206   return map->results;
207 }
208 AffineExpr AffineMap::getResult(unsigned idx) const {
209   assert(map && "uninitialized map storage");
210   return map->results[idx];
211 }
212 
213 /// Folds the results of the application of an affine map on the provided
214 /// operands to a constant if possible. Returns false if the folding happens,
215 /// true otherwise.
216 LogicalResult
217 AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
218                         SmallVectorImpl<Attribute> &results) const {
219   assert(getNumInputs() == operandConstants.size());
220 
221   // Fold each of the result expressions.
222   AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
223   // Constant fold each AffineExpr in AffineMap and add to 'results'.
224   for (auto expr : getResults()) {
225     auto folded = exprFolder.constantFold(expr);
226     // If we didn't fold to a constant, then folding fails.
227     if (!folded)
228       return failure();
229 
230     results.push_back(folded);
231   }
232   assert(results.size() == getNumResults() &&
233          "constant folding produced the wrong number of results");
234   return success();
235 }
236 
237 /// Walk all of the AffineExpr's in this mapping. Each node in an expression
238 /// tree is visited in postorder.
239 void AffineMap::walkExprs(std::function<void(AffineExpr)> callback) const {
240   for (auto expr : getResults())
241     expr.walk(callback);
242 }
243 
244 /// This method substitutes any uses of dimensions and symbols (e.g.
245 /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
246 /// expression mapping.  Because this can be used to eliminate dims and
247 /// symbols, the client needs to specify the number of dims and symbols in
248 /// the result.  The returned map always has the same number of results.
249 AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
250                                            ArrayRef<AffineExpr> symReplacements,
251                                            unsigned numResultDims,
252                                            unsigned numResultSyms) {
253   SmallVector<AffineExpr, 8> results;
254   results.reserve(getNumResults());
255   for (auto expr : getResults())
256     results.push_back(
257         expr.replaceDimsAndSymbols(dimReplacements, symReplacements));
258 
259   return results.empty() ? get(numResultDims, 0, getContext())
260                          : get(numResultDims, numResultSyms, results);
261 }
262 
263 AffineMap AffineMap::compose(AffineMap map) {
264   assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
265   // Prepare `map` by concatenating the symbols and rewriting its exprs.
266   unsigned numDims = map.getNumDims();
267   unsigned numSymbolsThisMap = getNumSymbols();
268   unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols();
269   SmallVector<AffineExpr, 8> newDims(numDims);
270   for (unsigned idx = 0; idx < numDims; ++idx) {
271     newDims[idx] = getAffineDimExpr(idx, getContext());
272   }
273   SmallVector<AffineExpr, 8> newSymbols(numSymbols);
274   for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) {
275     newSymbols[idx - numSymbolsThisMap] =
276         getAffineSymbolExpr(idx, getContext());
277   }
278   auto newMap =
279       map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols);
280   SmallVector<AffineExpr, 8> exprs;
281   exprs.reserve(getResults().size());
282   for (auto expr : getResults())
283     exprs.push_back(expr.compose(newMap));
284   return exprs.empty() ? AffineMap::get(numDims, 0, map.getContext())
285                        : AffineMap::get(numDims, numSymbols, exprs);
286 }
287 
288 bool AffineMap::isProjectedPermutation() {
289   if (getNumSymbols() > 0)
290     return false;
291   SmallVector<bool, 8> seen(getNumInputs(), false);
292   for (auto expr : getResults()) {
293     if (auto dim = expr.dyn_cast<AffineDimExpr>()) {
294       if (seen[dim.getPosition()])
295         return false;
296       seen[dim.getPosition()] = true;
297       continue;
298     }
299     return false;
300   }
301   return true;
302 }
303 
304 bool AffineMap::isPermutation() {
305   if (getNumDims() != getNumResults())
306     return false;
307   return isProjectedPermutation();
308 }
309 
310 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) {
311   SmallVector<AffineExpr, 4> exprs;
312   exprs.reserve(resultPos.size());
313   for (auto idx : resultPos) {
314     exprs.push_back(getResult(idx));
315   }
316   return AffineMap::get(getNumDims(), getNumSymbols(), exprs);
317 }
318 
319 AffineMap mlir::simplifyAffineMap(AffineMap map) {
320   SmallVector<AffineExpr, 8> exprs;
321   for (auto e : map.getResults()) {
322     exprs.push_back(
323         simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols()));
324   }
325   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs);
326 }
327 
328 AffineMap mlir::removeDuplicateExprs(AffineMap map) {
329   auto results = map.getResults();
330   SmallVector<AffineExpr, 4> uniqueExprs(results.begin(), results.end());
331   uniqueExprs.erase(std::unique(uniqueExprs.begin(), uniqueExprs.end()),
332                     uniqueExprs.end());
333   return AffineMap::get(map.getNumDims(), map.getNumSymbols(), uniqueExprs,
334                         map.getContext());
335 }
336 
337 AffineMap mlir::inversePermutation(AffineMap map) {
338   if (map.isEmpty())
339     return map;
340   assert(map.getNumSymbols() == 0 && "expected map without symbols");
341   SmallVector<AffineExpr, 4> exprs(map.getNumDims());
342   for (auto en : llvm::enumerate(map.getResults())) {
343     auto expr = en.value();
344     // Skip non-permutations.
345     if (auto d = expr.dyn_cast<AffineDimExpr>()) {
346       if (exprs[d.getPosition()])
347         continue;
348       exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
349     }
350   }
351   SmallVector<AffineExpr, 4> seenExprs;
352   seenExprs.reserve(map.getNumDims());
353   for (auto expr : exprs)
354     if (expr)
355       seenExprs.push_back(expr);
356   if (seenExprs.size() != map.getNumInputs())
357     return AffineMap();
358   return AffineMap::get(map.getNumResults(), 0, seenExprs);
359 }
360 
361 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
362   unsigned numResults = 0;
363   for (auto m : maps)
364     numResults += m.getNumResults();
365   unsigned numDims = 0;
366   SmallVector<AffineExpr, 8> results;
367   results.reserve(numResults);
368   for (auto m : maps) {
369     assert(m.getNumSymbols() == 0 && "expected map without symbols");
370     results.append(m.getResults().begin(), m.getResults().end());
371     numDims = std::max(m.getNumDims(), numDims);
372   }
373   return results.empty() ? AffineMap::get(numDims, /*numSymbols=*/0,
374                                           maps.front().getContext())
375                          : AffineMap::get(numDims, /*numSymbols=*/0, results);
376 }
377 
378 //===----------------------------------------------------------------------===//
379 // MutableAffineMap.
380 //===----------------------------------------------------------------------===//
381 
382 MutableAffineMap::MutableAffineMap(AffineMap map)
383     : numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
384       // A map always has at least 1 result by construction
385       context(map.getResult(0).getContext()) {
386   for (auto result : map.getResults())
387     results.push_back(result);
388 }
389 
390 void MutableAffineMap::reset(AffineMap map) {
391   results.clear();
392   numDims = map.getNumDims();
393   numSymbols = map.getNumSymbols();
394   // A map always has at least 1 result by construction
395   context = map.getResult(0).getContext();
396   for (auto result : map.getResults())
397     results.push_back(result);
398 }
399 
400 bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
401   if (results[idx].isMultipleOf(factor))
402     return true;
403 
404   // TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to
405   // complete this (for a more powerful analysis).
406   return false;
407 }
408 
409 // Simplifies the result affine expressions of this map. The expressions have to
410 // be pure for the simplification implemented.
411 void MutableAffineMap::simplify() {
412   // Simplify each of the results if possible.
413   // TODO(ntv): functional-style map
414   for (unsigned i = 0, e = getNumResults(); i < e; i++) {
415     results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols);
416   }
417 }
418 
419 AffineMap MutableAffineMap::getAffineMap() const {
420   return AffineMap::get(numDims, numSymbols, results);
421 }
422