1 //===- AffineMap.cpp - MLIR Affine Map Classes ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
#include "mlir/IR/AffineMap.h"
#include "AffineMapDetail.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
19
20 using namespace mlir;
21
22 namespace {
23
24 // AffineExprConstantFolder evaluates an affine expression using constant
25 // operands passed in 'operandConsts'. Returns an IntegerAttr attribute
26 // representing the constant value of the affine expression evaluated on
27 // constant 'operandConsts', or nullptr if it can't be folded.
28 class AffineExprConstantFolder {
29 public:
AffineExprConstantFolder(unsigned numDims,ArrayRef<Attribute> operandConsts)30 AffineExprConstantFolder(unsigned numDims, ArrayRef<Attribute> operandConsts)
31 : numDims(numDims), operandConsts(operandConsts) {}
32
33 /// Attempt to constant fold the specified affine expr, or return null on
34 /// failure.
constantFold(AffineExpr expr)35 IntegerAttr constantFold(AffineExpr expr) {
36 if (auto result = constantFoldImpl(expr))
37 return IntegerAttr::get(IndexType::get(expr.getContext()), *result);
38 return nullptr;
39 }
40
41 private:
constantFoldImpl(AffineExpr expr)42 Optional<int64_t> constantFoldImpl(AffineExpr expr) {
43 switch (expr.getKind()) {
44 case AffineExprKind::Add:
45 return constantFoldBinExpr(
46 expr, [](int64_t lhs, int64_t rhs) { return lhs + rhs; });
47 case AffineExprKind::Mul:
48 return constantFoldBinExpr(
49 expr, [](int64_t lhs, int64_t rhs) { return lhs * rhs; });
50 case AffineExprKind::Mod:
51 return constantFoldBinExpr(
52 expr, [](int64_t lhs, int64_t rhs) { return mod(lhs, rhs); });
53 case AffineExprKind::FloorDiv:
54 return constantFoldBinExpr(
55 expr, [](int64_t lhs, int64_t rhs) { return floorDiv(lhs, rhs); });
56 case AffineExprKind::CeilDiv:
57 return constantFoldBinExpr(
58 expr, [](int64_t lhs, int64_t rhs) { return ceilDiv(lhs, rhs); });
59 case AffineExprKind::Constant:
60 return expr.cast<AffineConstantExpr>().getValue();
61 case AffineExprKind::DimId:
62 if (auto attr = operandConsts[expr.cast<AffineDimExpr>().getPosition()]
63 .dyn_cast_or_null<IntegerAttr>())
64 return attr.getInt();
65 return llvm::None;
66 case AffineExprKind::SymbolId:
67 if (auto attr = operandConsts[numDims +
68 expr.cast<AffineSymbolExpr>().getPosition()]
69 .dyn_cast_or_null<IntegerAttr>())
70 return attr.getInt();
71 return llvm::None;
72 }
73 llvm_unreachable("Unknown AffineExpr");
74 }
75
76 // TODO: Change these to operate on APInts too.
constantFoldBinExpr(AffineExpr expr,int64_t (* op)(int64_t,int64_t))77 Optional<int64_t> constantFoldBinExpr(AffineExpr expr,
78 int64_t (*op)(int64_t, int64_t)) {
79 auto binOpExpr = expr.cast<AffineBinaryOpExpr>();
80 if (auto lhs = constantFoldImpl(binOpExpr.getLHS()))
81 if (auto rhs = constantFoldImpl(binOpExpr.getRHS()))
82 return op(*lhs, *rhs);
83 return llvm::None;
84 }
85
86 // The number of dimension operands in AffineMap containing this expression.
87 unsigned numDims;
88 // The constant valued operands used to evaluate this AffineExpr.
89 ArrayRef<Attribute> operandConsts;
90 };
91
92 } // namespace
93
94 /// Returns a single constant result affine map.
getConstantMap(int64_t val,MLIRContext * context)95 AffineMap AffineMap::getConstantMap(int64_t val, MLIRContext *context) {
96 return get(/*dimCount=*/0, /*symbolCount=*/0,
97 {getAffineConstantExpr(val, context)});
98 }
99
100 /// Returns an identity affine map (d0, ..., dn) -> (dp, ..., dn) on the most
101 /// minor dimensions.
getMinorIdentityMap(unsigned dims,unsigned results,MLIRContext * context)102 AffineMap AffineMap::getMinorIdentityMap(unsigned dims, unsigned results,
103 MLIRContext *context) {
104 assert(dims >= results && "Dimension mismatch");
105 auto id = AffineMap::getMultiDimIdentityMap(dims, context);
106 return AffineMap::get(dims, 0, id.getResults().take_back(results), context);
107 }
108
isMinorIdentity() const109 bool AffineMap::isMinorIdentity() const {
110 return getNumDims() >= getNumResults() &&
111 *this ==
112 getMinorIdentityMap(getNumDims(), getNumResults(), getContext());
113 }
114
115 /// Returns true if this affine map is a minor identity up to broadcasted
116 /// dimensions which are indicated by value 0 in the result.
isMinorIdentityWithBroadcasting(SmallVectorImpl<unsigned> * broadcastedDims) const117 bool AffineMap::isMinorIdentityWithBroadcasting(
118 SmallVectorImpl<unsigned> *broadcastedDims) const {
119 if (broadcastedDims)
120 broadcastedDims->clear();
121 if (getNumDims() < getNumResults())
122 return false;
123 unsigned suffixStart = getNumDims() - getNumResults();
124 for (const auto &idxAndExpr : llvm::enumerate(getResults())) {
125 unsigned resIdx = idxAndExpr.index();
126 AffineExpr expr = idxAndExpr.value();
127 if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
128 // Each result may be either a constant 0 (broadcasted dimension).
129 if (constExpr.getValue() != 0)
130 return false;
131 if (broadcastedDims)
132 broadcastedDims->push_back(resIdx);
133 } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
134 // Or it may be the input dimension corresponding to this result position.
135 if (dimExpr.getPosition() != suffixStart + resIdx)
136 return false;
137 } else {
138 return false;
139 }
140 }
141 return true;
142 }
143
144 /// Return true if this affine map can be converted to a minor identity with
145 /// broadcast by doing a permute. Return a permutation (there may be
146 /// several) to apply to get to a minor identity with broadcasts.
147 /// Ex:
148 /// * (d0, d1, d2) -> (0, d1) maps to minor identity (d1, 0 = d2) with
149 /// perm = [1, 0] and broadcast d2
150 /// * (d0, d1, d2) -> (d0, 0) cannot be mapped to a minor identity by
151 /// permutation + broadcast
152 /// * (d0, d1, d2, d3) -> (0, d1, d3) maps to minor identity (d1, 0 = d2, d3)
153 /// with perm = [1, 0, 2] and broadcast d2
154 /// * (d0, d1) -> (d1, 0, 0, d0) maps to minor identity (d0, d1) with extra
155 /// leading broadcat dimensions. The map returned would be (0, 0, d0, d1) with
156 /// perm = [3, 0, 1, 2]
isPermutationOfMinorIdentityWithBroadcasting(SmallVectorImpl<unsigned> & permutedDims) const157 bool AffineMap::isPermutationOfMinorIdentityWithBroadcasting(
158 SmallVectorImpl<unsigned> &permutedDims) const {
159 unsigned projectionStart =
160 getNumResults() < getNumInputs() ? getNumInputs() - getNumResults() : 0;
161 permutedDims.clear();
162 SmallVector<unsigned> broadcastDims;
163 permutedDims.resize(getNumResults(), 0);
164 // If there are more results than input dimensions we want the new map to
165 // start with broadcast dimensions in order to be a minor identity with
166 // broadcasting.
167 unsigned leadingBroadcast =
168 getNumResults() > getNumInputs() ? getNumResults() - getNumInputs() : 0;
169 llvm::SmallBitVector dimFound(std::max(getNumInputs(), getNumResults()),
170 false);
171 for (const auto &idxAndExpr : llvm::enumerate(getResults())) {
172 unsigned resIdx = idxAndExpr.index();
173 AffineExpr expr = idxAndExpr.value();
174 // Each result may be either a constant 0 (broadcast dimension) or a
175 // dimension.
176 if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
177 if (constExpr.getValue() != 0)
178 return false;
179 broadcastDims.push_back(resIdx);
180 } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
181 if (dimExpr.getPosition() < projectionStart)
182 return false;
183 unsigned newPosition =
184 dimExpr.getPosition() - projectionStart + leadingBroadcast;
185 permutedDims[resIdx] = newPosition;
186 dimFound[newPosition] = true;
187 } else {
188 return false;
189 }
190 }
191 // Find a permuation for the broadcast dimension. Since they are broadcasted
192 // any valid permutation is acceptable. We just permute the dim into a slot
193 // without an existing dimension.
194 unsigned pos = 0;
195 for (auto dim : broadcastDims) {
196 while (pos < dimFound.size() && dimFound[pos]) {
197 pos++;
198 }
199 permutedDims[dim] = pos++;
200 }
201 return true;
202 }
203
204 /// Returns an AffineMap representing a permutation.
getPermutationMap(ArrayRef<unsigned> permutation,MLIRContext * context)205 AffineMap AffineMap::getPermutationMap(ArrayRef<unsigned> permutation,
206 MLIRContext *context) {
207 assert(!permutation.empty() &&
208 "Cannot create permutation map from empty permutation vector");
209 SmallVector<AffineExpr, 4> affExprs;
210 for (auto index : permutation)
211 affExprs.push_back(getAffineDimExpr(index, context));
212 const auto *m = std::max_element(permutation.begin(), permutation.end());
213 auto permutationMap = AffineMap::get(*m + 1, 0, affExprs, context);
214 assert(permutationMap.isPermutation() && "Invalid permutation vector");
215 return permutationMap;
216 }
217
218 template <typename AffineExprContainer>
219 static SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<AffineExprContainer> exprsList)220 inferFromExprList(ArrayRef<AffineExprContainer> exprsList) {
221 assert(!exprsList.empty());
222 assert(!exprsList[0].empty());
223 auto context = exprsList[0][0].getContext();
224 int64_t maxDim = -1, maxSym = -1;
225 getMaxDimAndSymbol(exprsList, maxDim, maxSym);
226 SmallVector<AffineMap, 4> maps;
227 maps.reserve(exprsList.size());
228 for (const auto &exprs : exprsList)
229 maps.push_back(AffineMap::get(/*dimCount=*/maxDim + 1,
230 /*symbolCount=*/maxSym + 1, exprs, context));
231 return maps;
232 }
233
234 SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList)235 AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>> exprsList) {
236 return ::inferFromExprList(exprsList);
237 }
238
239 SmallVector<AffineMap, 4>
inferFromExprList(ArrayRef<SmallVector<AffineExpr,4>> exprsList)240 AffineMap::inferFromExprList(ArrayRef<SmallVector<AffineExpr, 4>> exprsList) {
241 return ::inferFromExprList(exprsList);
242 }
243
getMultiDimIdentityMap(unsigned numDims,MLIRContext * context)244 AffineMap AffineMap::getMultiDimIdentityMap(unsigned numDims,
245 MLIRContext *context) {
246 SmallVector<AffineExpr, 4> dimExprs;
247 dimExprs.reserve(numDims);
248 for (unsigned i = 0; i < numDims; ++i)
249 dimExprs.push_back(mlir::getAffineDimExpr(i, context));
250 return get(/*dimCount=*/numDims, /*symbolCount=*/0, dimExprs, context);
251 }
252
getContext() const253 MLIRContext *AffineMap::getContext() const { return map->context; }
254
isIdentity() const255 bool AffineMap::isIdentity() const {
256 if (getNumDims() != getNumResults())
257 return false;
258 ArrayRef<AffineExpr> results = getResults();
259 for (unsigned i = 0, numDims = getNumDims(); i < numDims; ++i) {
260 auto expr = results[i].dyn_cast<AffineDimExpr>();
261 if (!expr || expr.getPosition() != i)
262 return false;
263 }
264 return true;
265 }
266
isEmpty() const267 bool AffineMap::isEmpty() const {
268 return getNumDims() == 0 && getNumSymbols() == 0 && getNumResults() == 0;
269 }
270
isSingleConstant() const271 bool AffineMap::isSingleConstant() const {
272 return getNumResults() == 1 && getResult(0).isa<AffineConstantExpr>();
273 }
274
isConstant() const275 bool AffineMap::isConstant() const {
276 return llvm::all_of(getResults(), [](AffineExpr expr) {
277 return expr.isa<AffineConstantExpr>();
278 });
279 }
280
getSingleConstantResult() const281 int64_t AffineMap::getSingleConstantResult() const {
282 assert(isSingleConstant() && "map must have a single constant result");
283 return getResult(0).cast<AffineConstantExpr>().getValue();
284 }
285
getConstantResults() const286 SmallVector<int64_t> AffineMap::getConstantResults() const {
287 assert(isConstant() && "map must have only constant results");
288 SmallVector<int64_t> result;
289 for (auto expr : getResults())
290 result.emplace_back(expr.cast<AffineConstantExpr>().getValue());
291 return result;
292 }
293
getNumDims() const294 unsigned AffineMap::getNumDims() const {
295 assert(map && "uninitialized map storage");
296 return map->numDims;
297 }
getNumSymbols() const298 unsigned AffineMap::getNumSymbols() const {
299 assert(map && "uninitialized map storage");
300 return map->numSymbols;
301 }
getNumResults() const302 unsigned AffineMap::getNumResults() const { return getResults().size(); }
getNumInputs() const303 unsigned AffineMap::getNumInputs() const {
304 assert(map && "uninitialized map storage");
305 return map->numDims + map->numSymbols;
306 }
getResults() const307 ArrayRef<AffineExpr> AffineMap::getResults() const {
308 assert(map && "uninitialized map storage");
309 return map->results();
310 }
getResult(unsigned idx) const311 AffineExpr AffineMap::getResult(unsigned idx) const {
312 return getResults()[idx];
313 }
314
getDimPosition(unsigned idx) const315 unsigned AffineMap::getDimPosition(unsigned idx) const {
316 return getResult(idx).cast<AffineDimExpr>().getPosition();
317 }
318
getPermutedPosition(unsigned input) const319 unsigned AffineMap::getPermutedPosition(unsigned input) const {
320 assert(isPermutation() && "invalid permutation request");
321 for (unsigned i = 0, numResults = getNumResults(); i < numResults; i++)
322 if (getDimPosition(i) == input)
323 return i;
324 llvm_unreachable("incorrect permutation request");
325 }
326
327 /// Folds the results of the application of an affine map on the provided
328 /// operands to a constant if possible. Returns false if the folding happens,
329 /// true otherwise.
330 LogicalResult
constantFold(ArrayRef<Attribute> operandConstants,SmallVectorImpl<Attribute> & results) const331 AffineMap::constantFold(ArrayRef<Attribute> operandConstants,
332 SmallVectorImpl<Attribute> &results) const {
333 // Attempt partial folding.
334 SmallVector<int64_t, 2> integers;
335 partialConstantFold(operandConstants, &integers);
336
337 // If all expressions folded to a constant, populate results with attributes
338 // containing those constants.
339 if (integers.empty())
340 return failure();
341
342 auto range = llvm::map_range(integers, [this](int64_t i) {
343 return IntegerAttr::get(IndexType::get(getContext()), i);
344 });
345 results.append(range.begin(), range.end());
346 return success();
347 }
348
349 AffineMap
partialConstantFold(ArrayRef<Attribute> operandConstants,SmallVectorImpl<int64_t> * results) const350 AffineMap::partialConstantFold(ArrayRef<Attribute> operandConstants,
351 SmallVectorImpl<int64_t> *results) const {
352 assert(getNumInputs() == operandConstants.size());
353
354 // Fold each of the result expressions.
355 AffineExprConstantFolder exprFolder(getNumDims(), operandConstants);
356 SmallVector<AffineExpr, 4> exprs;
357 exprs.reserve(getNumResults());
358
359 for (auto expr : getResults()) {
360 auto folded = exprFolder.constantFold(expr);
361 // If did not fold to a constant, keep the original expression, and clear
362 // the integer results vector.
363 if (folded) {
364 exprs.push_back(
365 getAffineConstantExpr(folded.getInt(), folded.getContext()));
366 if (results)
367 results->push_back(folded.getInt());
368 } else {
369 exprs.push_back(expr);
370 if (results) {
371 results->clear();
372 results = nullptr;
373 }
374 }
375 }
376
377 return get(getNumDims(), getNumSymbols(), exprs, getContext());
378 }
379
380 /// Walk all of the AffineExpr's in this mapping. Each node in an expression
381 /// tree is visited in postorder.
walkExprs(llvm::function_ref<void (AffineExpr)> callback) const382 void AffineMap::walkExprs(llvm::function_ref<void(AffineExpr)> callback) const {
383 for (auto expr : getResults())
384 expr.walk(callback);
385 }
386
387 /// This method substitutes any uses of dimensions and symbols (e.g.
388 /// dim#0 with dimReplacements[0]) in subexpressions and returns the modified
389 /// expression mapping. Because this can be used to eliminate dims and
390 /// symbols, the client needs to specify the number of dims and symbols in
391 /// the result. The returned map always has the same number of results.
replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,ArrayRef<AffineExpr> symReplacements,unsigned numResultDims,unsigned numResultSyms) const392 AffineMap AffineMap::replaceDimsAndSymbols(ArrayRef<AffineExpr> dimReplacements,
393 ArrayRef<AffineExpr> symReplacements,
394 unsigned numResultDims,
395 unsigned numResultSyms) const {
396 SmallVector<AffineExpr, 8> results;
397 results.reserve(getNumResults());
398 for (auto expr : getResults())
399 results.push_back(
400 expr.replaceDimsAndSymbols(dimReplacements, symReplacements));
401 return get(numResultDims, numResultSyms, results, getContext());
402 }
403
404 /// Sparse replace method. Apply AffineExpr::replace(`expr`, `replacement`) to
405 /// each of the results and return a new AffineMap with the new results and
406 /// with the specified number of dims and symbols.
replace(AffineExpr expr,AffineExpr replacement,unsigned numResultDims,unsigned numResultSyms) const407 AffineMap AffineMap::replace(AffineExpr expr, AffineExpr replacement,
408 unsigned numResultDims,
409 unsigned numResultSyms) const {
410 SmallVector<AffineExpr, 4> newResults;
411 newResults.reserve(getNumResults());
412 for (AffineExpr e : getResults())
413 newResults.push_back(e.replace(expr, replacement));
414 return AffineMap::get(numResultDims, numResultSyms, newResults, getContext());
415 }
416
417 /// Sparse replace method. Apply AffineExpr::replace(`map`) to each of the
418 /// results and return a new AffineMap with the new results and with the
419 /// specified number of dims and symbols.
replace(const DenseMap<AffineExpr,AffineExpr> & map,unsigned numResultDims,unsigned numResultSyms) const420 AffineMap AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map,
421 unsigned numResultDims,
422 unsigned numResultSyms) const {
423 SmallVector<AffineExpr, 4> newResults;
424 newResults.reserve(getNumResults());
425 for (AffineExpr e : getResults())
426 newResults.push_back(e.replace(map));
427 return AffineMap::get(numResultDims, numResultSyms, newResults, getContext());
428 }
429
430 AffineMap
replace(const DenseMap<AffineExpr,AffineExpr> & map) const431 AffineMap::replace(const DenseMap<AffineExpr, AffineExpr> &map) const {
432 SmallVector<AffineExpr, 4> newResults;
433 newResults.reserve(getNumResults());
434 for (AffineExpr e : getResults())
435 newResults.push_back(e.replace(map));
436 return AffineMap::inferFromExprList(newResults).front();
437 }
438
compose(AffineMap map) const439 AffineMap AffineMap::compose(AffineMap map) const {
440 assert(getNumDims() == map.getNumResults() && "Number of results mismatch");
441 // Prepare `map` by concatenating the symbols and rewriting its exprs.
442 unsigned numDims = map.getNumDims();
443 unsigned numSymbolsThisMap = getNumSymbols();
444 unsigned numSymbols = numSymbolsThisMap + map.getNumSymbols();
445 SmallVector<AffineExpr, 8> newDims(numDims);
446 for (unsigned idx = 0; idx < numDims; ++idx) {
447 newDims[idx] = getAffineDimExpr(idx, getContext());
448 }
449 SmallVector<AffineExpr, 8> newSymbols(numSymbols - numSymbolsThisMap);
450 for (unsigned idx = numSymbolsThisMap; idx < numSymbols; ++idx) {
451 newSymbols[idx - numSymbolsThisMap] =
452 getAffineSymbolExpr(idx, getContext());
453 }
454 auto newMap =
455 map.replaceDimsAndSymbols(newDims, newSymbols, numDims, numSymbols);
456 SmallVector<AffineExpr, 8> exprs;
457 exprs.reserve(getResults().size());
458 for (auto expr : getResults())
459 exprs.push_back(expr.compose(newMap));
460 return AffineMap::get(numDims, numSymbols, exprs, map.getContext());
461 }
462
compose(ArrayRef<int64_t> values) const463 SmallVector<int64_t, 4> AffineMap::compose(ArrayRef<int64_t> values) const {
464 assert(getNumSymbols() == 0 && "Expected symbol-less map");
465 SmallVector<AffineExpr, 4> exprs;
466 exprs.reserve(values.size());
467 MLIRContext *ctx = getContext();
468 for (auto v : values)
469 exprs.push_back(getAffineConstantExpr(v, ctx));
470 auto resMap = compose(AffineMap::get(0, 0, exprs, ctx));
471 SmallVector<int64_t, 4> res;
472 res.reserve(resMap.getNumResults());
473 for (auto e : resMap.getResults())
474 res.push_back(e.cast<AffineConstantExpr>().getValue());
475 return res;
476 }
477
isProjectedPermutation(bool allowZeroInResults) const478 bool AffineMap::isProjectedPermutation(bool allowZeroInResults) const {
479 if (getNumSymbols() > 0)
480 return false;
481
482 // Having more results than inputs means that results have duplicated dims or
483 // zeros that can't be mapped to input dims.
484 if (getNumResults() > getNumInputs())
485 return false;
486
487 SmallVector<bool, 8> seen(getNumInputs(), false);
488 // A projected permutation can have, at most, only one instance of each input
489 // dimension in the result expressions. Zeros are allowed as long as the
490 // number of result expressions is lower or equal than the number of input
491 // expressions.
492 for (auto expr : getResults()) {
493 if (auto dim = expr.dyn_cast<AffineDimExpr>()) {
494 if (seen[dim.getPosition()])
495 return false;
496 seen[dim.getPosition()] = true;
497 } else {
498 auto constExpr = expr.dyn_cast<AffineConstantExpr>();
499 if (!allowZeroInResults || !constExpr || constExpr.getValue() != 0)
500 return false;
501 }
502 }
503
504 // Results are either dims or zeros and zeros can be mapped to input dims.
505 return true;
506 }
507
isPermutation() const508 bool AffineMap::isPermutation() const {
509 if (getNumDims() != getNumResults())
510 return false;
511 return isProjectedPermutation();
512 }
513
getSubMap(ArrayRef<unsigned> resultPos) const514 AffineMap AffineMap::getSubMap(ArrayRef<unsigned> resultPos) const {
515 SmallVector<AffineExpr, 4> exprs;
516 exprs.reserve(resultPos.size());
517 for (auto idx : resultPos)
518 exprs.push_back(getResult(idx));
519 return AffineMap::get(getNumDims(), getNumSymbols(), exprs, getContext());
520 }
521
getSliceMap(unsigned start,unsigned length) const522 AffineMap AffineMap::getSliceMap(unsigned start, unsigned length) const {
523 return AffineMap::get(getNumDims(), getNumSymbols(),
524 getResults().slice(start, length), getContext());
525 }
526
getMajorSubMap(unsigned numResults) const527 AffineMap AffineMap::getMajorSubMap(unsigned numResults) const {
528 if (numResults == 0)
529 return AffineMap();
530 if (numResults > getNumResults())
531 return *this;
532 return getSliceMap(0, numResults);
533 }
534
getMinorSubMap(unsigned numResults) const535 AffineMap AffineMap::getMinorSubMap(unsigned numResults) const {
536 if (numResults == 0)
537 return AffineMap();
538 if (numResults > getNumResults())
539 return *this;
540 return getSliceMap(getNumResults() - numResults, numResults);
541 }
542
compressDims(AffineMap map,const llvm::SmallBitVector & unusedDims)543 AffineMap mlir::compressDims(AffineMap map,
544 const llvm::SmallBitVector &unusedDims) {
545 unsigned numDims = 0;
546 SmallVector<AffineExpr> dimReplacements;
547 dimReplacements.reserve(map.getNumDims());
548 MLIRContext *context = map.getContext();
549 for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) {
550 if (unusedDims.test(dim))
551 dimReplacements.push_back(getAffineConstantExpr(0, context));
552 else
553 dimReplacements.push_back(getAffineDimExpr(numDims++, context));
554 }
555 SmallVector<AffineExpr> resultExprs;
556 resultExprs.reserve(map.getNumResults());
557 for (auto e : map.getResults())
558 resultExprs.push_back(e.replaceDims(dimReplacements));
559 return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context);
560 }
561
compressUnusedDims(AffineMap map)562 AffineMap mlir::compressUnusedDims(AffineMap map) {
563 return compressDims(map, getUnusedDimsBitVector({map}));
564 }
565
566 static SmallVector<AffineMap>
compressUnusedImpl(ArrayRef<AffineMap> maps,llvm::function_ref<AffineMap (AffineMap)> compressionFun)567 compressUnusedImpl(ArrayRef<AffineMap> maps,
568 llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
569 if (maps.empty())
570 return SmallVector<AffineMap>();
571 SmallVector<AffineExpr> allExprs;
572 allExprs.reserve(maps.size() * maps.front().getNumResults());
573 unsigned numDims = maps.front().getNumDims(),
574 numSymbols = maps.front().getNumSymbols();
575 for (auto m : maps) {
576 assert(numDims == m.getNumDims() && numSymbols == m.getNumSymbols() &&
577 "expected maps with same num dims and symbols");
578 llvm::append_range(allExprs, m.getResults());
579 }
580 AffineMap unifiedMap = compressionFun(
581 AffineMap::get(numDims, numSymbols, allExprs, maps.front().getContext()));
582 unsigned unifiedNumDims = unifiedMap.getNumDims(),
583 unifiedNumSymbols = unifiedMap.getNumSymbols();
584 ArrayRef<AffineExpr> unifiedResults = unifiedMap.getResults();
585 SmallVector<AffineMap> res;
586 res.reserve(maps.size());
587 for (auto m : maps) {
588 res.push_back(AffineMap::get(unifiedNumDims, unifiedNumSymbols,
589 unifiedResults.take_front(m.getNumResults()),
590 m.getContext()));
591 unifiedResults = unifiedResults.drop_front(m.getNumResults());
592 }
593 return res;
594 }
595
compressUnusedDims(ArrayRef<AffineMap> maps)596 SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) {
597 return compressUnusedImpl(maps,
598 [](AffineMap m) { return compressUnusedDims(m); });
599 }
600
compressSymbols(AffineMap map,const llvm::SmallBitVector & unusedSymbols)601 AffineMap mlir::compressSymbols(AffineMap map,
602 const llvm::SmallBitVector &unusedSymbols) {
603 unsigned numSymbols = 0;
604 SmallVector<AffineExpr> symReplacements;
605 symReplacements.reserve(map.getNumSymbols());
606 MLIRContext *context = map.getContext();
607 for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) {
608 if (unusedSymbols.test(sym))
609 symReplacements.push_back(getAffineConstantExpr(0, context));
610 else
611 symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context));
612 }
613 SmallVector<AffineExpr> resultExprs;
614 resultExprs.reserve(map.getNumResults());
615 for (auto e : map.getResults())
616 resultExprs.push_back(e.replaceSymbols(symReplacements));
617 return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context);
618 }
619
compressUnusedSymbols(AffineMap map)620 AffineMap mlir::compressUnusedSymbols(AffineMap map) {
621 llvm::SmallBitVector unusedSymbols(map.getNumSymbols(), true);
622 map.walkExprs([&](AffineExpr expr) {
623 if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
624 unusedSymbols.reset(symExpr.getPosition());
625 });
626 return compressSymbols(map, unusedSymbols);
627 }
628
compressUnusedSymbols(ArrayRef<AffineMap> maps)629 SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) {
630 return compressUnusedImpl(
631 maps, [](AffineMap m) { return compressUnusedSymbols(m); });
632 }
633
simplifyAffineMap(AffineMap map)634 AffineMap mlir::simplifyAffineMap(AffineMap map) {
635 SmallVector<AffineExpr, 8> exprs;
636 for (auto e : map.getResults()) {
637 exprs.push_back(
638 simplifyAffineExpr(e, map.getNumDims(), map.getNumSymbols()));
639 }
640 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), exprs,
641 map.getContext());
642 }
643
removeDuplicateExprs(AffineMap map)644 AffineMap mlir::removeDuplicateExprs(AffineMap map) {
645 auto results = map.getResults();
646 SmallVector<AffineExpr, 4> uniqueExprs(results.begin(), results.end());
647 uniqueExprs.erase(std::unique(uniqueExprs.begin(), uniqueExprs.end()),
648 uniqueExprs.end());
649 return AffineMap::get(map.getNumDims(), map.getNumSymbols(), uniqueExprs,
650 map.getContext());
651 }
652
inversePermutation(AffineMap map)653 AffineMap mlir::inversePermutation(AffineMap map) {
654 if (map.isEmpty())
655 return map;
656 assert(map.getNumSymbols() == 0 && "expected map without symbols");
657 SmallVector<AffineExpr, 4> exprs(map.getNumDims());
658 for (const auto &en : llvm::enumerate(map.getResults())) {
659 auto expr = en.value();
660 // Skip non-permutations.
661 if (auto d = expr.dyn_cast<AffineDimExpr>()) {
662 if (exprs[d.getPosition()])
663 continue;
664 exprs[d.getPosition()] = getAffineDimExpr(en.index(), d.getContext());
665 }
666 }
667 SmallVector<AffineExpr, 4> seenExprs;
668 seenExprs.reserve(map.getNumDims());
669 for (auto expr : exprs)
670 if (expr)
671 seenExprs.push_back(expr);
672 if (seenExprs.size() != map.getNumInputs())
673 return AffineMap();
674 return AffineMap::get(map.getNumResults(), 0, seenExprs, map.getContext());
675 }
676
inverseAndBroadcastProjectedPermutation(AffineMap map)677 AffineMap mlir::inverseAndBroadcastProjectedPermutation(AffineMap map) {
678 assert(map.isProjectedPermutation(/*allowZeroInResults=*/true));
679 MLIRContext *context = map.getContext();
680 AffineExpr zero = mlir::getAffineConstantExpr(0, context);
681 // Start with all the results as 0.
682 SmallVector<AffineExpr, 4> exprs(map.getNumInputs(), zero);
683 for (unsigned i : llvm::seq(unsigned(0), map.getNumResults())) {
684 // Skip zeros from input map. 'exprs' is already initialized to zero.
685 if (auto constExpr = map.getResult(i).dyn_cast<AffineConstantExpr>()) {
686 assert(constExpr.getValue() == 0 &&
687 "Unexpected constant in projected permutation");
688 (void)constExpr;
689 continue;
690 }
691
692 // Reverse each dimension existing in the original map result.
693 exprs[map.getDimPosition(i)] = getAffineDimExpr(i, context);
694 }
695 return AffineMap::get(map.getNumResults(), /*symbolCount=*/0, exprs, context);
696 }
697
concatAffineMaps(ArrayRef<AffineMap> maps)698 AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
699 unsigned numResults = 0, numDims = 0, numSymbols = 0;
700 for (auto m : maps)
701 numResults += m.getNumResults();
702 SmallVector<AffineExpr, 8> results;
703 results.reserve(numResults);
704 for (auto m : maps) {
705 for (auto res : m.getResults())
706 results.push_back(res.shiftSymbols(m.getNumSymbols(), numSymbols));
707
708 numSymbols += m.getNumSymbols();
709 numDims = std::max(m.getNumDims(), numDims);
710 }
711 return AffineMap::get(numDims, numSymbols, results,
712 maps.front().getContext());
713 }
714
getProjectedMap(AffineMap map,const llvm::SmallBitVector & unusedDims)715 AffineMap mlir::getProjectedMap(AffineMap map,
716 const llvm::SmallBitVector &unusedDims) {
717 return compressUnusedSymbols(compressDims(map, unusedDims));
718 }
719
getUnusedDimsBitVector(ArrayRef<AffineMap> maps)720 llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
721 unsigned numDims = maps[0].getNumDims();
722 llvm::SmallBitVector numDimsBitVector(numDims, true);
723 for (const auto &m : maps) {
724 for (unsigned i = 0; i < numDims; ++i) {
725 if (m.isFunctionOfDim(i))
726 numDimsBitVector.reset(i);
727 }
728 }
729 return numDimsBitVector;
730 }
731
732 //===----------------------------------------------------------------------===//
733 // MutableAffineMap.
734 //===----------------------------------------------------------------------===//
735
MutableAffineMap(AffineMap map)736 MutableAffineMap::MutableAffineMap(AffineMap map)
737 : results(map.getResults().begin(), map.getResults().end()),
738 numDims(map.getNumDims()), numSymbols(map.getNumSymbols()),
739 context(map.getContext()) {}
740
reset(AffineMap map)741 void MutableAffineMap::reset(AffineMap map) {
742 results.clear();
743 numDims = map.getNumDims();
744 numSymbols = map.getNumSymbols();
745 context = map.getContext();
746 llvm::append_range(results, map.getResults());
747 }
748
isMultipleOf(unsigned idx,int64_t factor) const749 bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
750 if (results[idx].isMultipleOf(factor))
751 return true;
752
753 // TODO: use simplifyAffineExpr and FlatAffineConstraints to
754 // complete this (for a more powerful analysis).
755 return false;
756 }
757
758 // Simplifies the result affine expressions of this map. The expressions have to
759 // be pure for the simplification implemented.
simplify()760 void MutableAffineMap::simplify() {
761 // Simplify each of the results if possible.
762 // TODO: functional-style map
763 for (unsigned i = 0, e = getNumResults(); i < e; i++) {
764 results[i] = simplifyAffineExpr(getResult(i), numDims, numSymbols);
765 }
766 }
767
getAffineMap() const768 AffineMap MutableAffineMap::getAffineMap() const {
769 return AffineMap::get(numDims, numSymbols, results, context);
770 }
771