//===- LinalgTransforms.cpp - Linalg transformations as patterns ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements logic and helpers to expose Linalg transforms as rewrite
// patterns.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <type_traits>

#define DEBUG_TYPE "linalg-transforms"

using namespace mlir;
using namespace mlir::edsc;
using namespace mlir::edsc::intrinsics;
using namespace mlir::linalg;

#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")

//===----------------------------------------------------------------------===//
// Transformations exposed as rewrite patterns.
//===----------------------------------------------------------------------===//
// Marker used as attribute name in generated Linalg rewriting transformations.
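// For example (illustrative IR), a pattern constructed with the filter
// `Identifier::get("tile", ctx)` only matches ops that carry the marker:
//   linalg.matmul {__internal_linalg_transform__ = "tile"}
//     ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
//     outs(%C : memref<?x?xf32>)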
const StringLiteral mlir::linalg::LinalgTransforms::kLinalgTransformMarker =
    "__internal_linalg_transform__";

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    ArrayRef<Identifier> matchDisjunction, Optional<Identifier> replacement)
    : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {}

mlir::linalg::LinalgTransformationFilter::LinalgTransformationFilter(
    FilterFunction f, ArrayRef<Identifier> matchDisjunction,
    Optional<Identifier> replacement)
    : filters(),
      matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()),
      replacement(replacement) {
  if (f)
    filters.push_back(f);
}

LogicalResult mlir::linalg::LinalgTransformationFilter::checkAndNotify(
    PatternRewriter &rewriter, Operation *op) const {
  if (llvm::any_of(filters,
                   [&](const FilterFunction &f) { return failed(f(op)); }))
    return failure();

  auto attr = op->getAttrOfType<StringAttr>(
      LinalgTransforms::kLinalgTransformMarker);

  if (!attr) {
    // 1. The op has no marker and no filter is expected: match.
    if (matchDisjunction.empty())
      return success();

    // 2. The op has no marker but one was expected: fail.
    return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
      diag << " does not have any filter from list: ";
      interleaveComma(matchDisjunction, diag);
    });
  }

  // 3. The op's marker matches one of the expected filters: match.
  for (auto filter : matchDisjunction)
    if (attr.getValue() == filter)
      return success();

  // 4. The marker matches none of the expected filters: fail.
  return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) {
    diag << " does not have any filter from list: ";
    interleaveComma(matchDisjunction, diag);
  });
}

void mlir::linalg::LinalgTransformationFilter::
    replaceLinalgTransformationFilter(PatternRewriter &rewriter,
                                      Operation *op) const {
  if (replacement.hasValue())
    op->setAttr(LinalgTransforms::kLinalgTransformMarker,
                rewriter.getStringAttr(replacement.getValue()));
  else
    op->removeAttr(Identifier::get(LinalgTransforms::kLinalgTransformMarker,
                                   rewriter.getContext()));
}

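/// Materialize the requested tile sizes as constants at the start of the
/// enclosing function. A minimal usage sketch (illustrative, not from this
/// file):
///   LinalgTilingOptions options;
///   options.setTileSizes({8, 16, 32});
/// The three ConstantIndexOps are only created once a tiling pattern invokes
/// the stored callback.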
LinalgTilingOptions &
mlir::linalg::LinalgTilingOptions::setTileSizes(ArrayRef<int64_t> ts) {
  SmallVector<int64_t, 4> tileSizes(ts.begin(), ts.end());
  tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) {
    OpBuilder::InsertionGuard guard(b);
    b.setInsertionPointToStart(
        &op->getParentOfType<FuncOp>().getBody().front());
    return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) {
      Value v = b.create<ConstantIndexOp>(op->getLoc(), s);
      return v;
    }));
  };
  return *this;
}

/// Try to compute a static bounding box for `operand`.
/// Return success if either:
///   1. The operand is already statically shaped; `result` is left unchanged.
///   2. The operand is (partially) dynamic; `result` is set to the result of
///      a freshly created PadTensorOp.
/// Return failure if the operand cannot be padded to a static shape.
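/// For example (illustrative IR), an operand produced by a subtensor whose
/// dynamic size is bounded by an affine.min:
///   %sz = affine.min affine_map<(d0) -> (16, d0)>(%n)
///   %st = subtensor %t[0] [%sz] [1] : tensor<?xf32> to tensor<?xf32>
/// has the static bounding box 16 and is padded to tensor<16xf32> by the
/// freshly created PadTensorOp.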
static LogicalResult padOperandToSmallestStaticBoundingBox(
    PatternRewriter &rewriter, linalg::LinalgOp opToPad, OpOperand &operand,
    const LinalgTilingOptions &options, Value &result) {
  auto tensorType = operand.get().getType().cast<RankedTensorType>();
  // Already statically shaped, no need to pad.
  if (tensorType.hasStaticShape())
    return success();
  auto subtensor = operand.get().getDefiningOp<SubTensorOp>();
  // Not a subtensor, cannot construct a static bounding box.
  if (!subtensor)
    return failure();
  SmallVector<int64_t> staticSizes;
  staticSizes.reserve(tensorType.getRank());
  auto shapedOp =
      cast<OffsetSizeAndStrideOpInterface>(subtensor.getOperation());
  for (auto size : shapedOp.getMixedSizes()) {
    auto indexAttr = size.is<Attribute>()
                         ? size.get<Attribute>().dyn_cast<IntegerAttr>()
                         : linalg::getSmallestBoundingIndex(size.get<Value>());
    // A static bounding index must exist for all sizes; for now, fail the
    // match if one cannot be found.
    if (!indexAttr)
      return rewriter.notifyMatchFailure(
          opToPad, "No constant bounding box can be found for padding");
    staticSizes.push_back(indexAttr.getInt());
  }
  Value pad = options.paddingValueComputationFunction(rewriter, operand);
  auto staticTensorType =
      RankedTensorType::get(staticSizes, tensorType.getElementType());
  result = linalg::PadTensorOp::createPadHighOp(
      staticTensorType, operand.get(), pad, opToPad->getLoc(), rewriter);
  return success();
}

// Try to create a static bounding box around each operand of `res.op`.
// If successful, `res.op` is rewritten in static form with padded operands
// and is updated to point at the cloned static op.
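// For example (illustrative), a matmul whose operands are all bounded by 16
// is cloned to operate on padded tensor<16x16xf32> operands; each original
// result is then recovered by taking a subtensor of the padded result, sized
// by memref.dim ops on the original results.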
static LogicalResult rewriteAsPaddedOp(PatternRewriter &rewriter,
                                       TiledLinalgOp &res,
                                       const LinalgTilingOptions &options) {
  LinalgOp opToPad = res.op;
  Location loc = opToPad->getLoc();

  // If the op is fully static, it does not need padding.
  // TODO: there are cases where we may still want to pad to larger sizes.
  if (llvm::all_of(opToPad.getShapedOperands(), [](Value v) {
        return v.getType().cast<RankedTensorType>().hasStaticShape();
      }))
    return success();

  OpBuilder::InsertionGuard g(rewriter);
  // Set the insertion point after the op because we also read the dims of the
  // original output.
  rewriter.setInsertionPointAfter(opToPad);
  // Make a copy of the shaped operands and update it.
  SmallVector<Value> newOperands;
  newOperands.reserve(opToPad.getNumShapedOperands());
  for (OpOperand &operand : opToPad.getShapedOpOperands()) {
    Value paddedOperand;
    // If padding was requested but the shape cannot be bounded statically then
    // the pattern fails to apply.
    if (failed(padOperandToSmallestStaticBoundingBox(rewriter, opToPad, operand,
                                                     options, paddedOperand))) {
      return failure();
    }
    newOperands.push_back(paddedOperand ? paddedOperand : operand.get());
  }

  // Clone `opToPad` to operate on the statically padded shapes.
  auto resultTensorTypes =
      ValueRange(newOperands).take_back(opToPad.getNumOutputs()).getTypes();
  ValueRange otherOperands = opToPad.getAssumedNonShapedOperands();
  newOperands.append(otherOperands.begin(), otherOperands.end());
  linalg::LinalgOp paddedOp =
      opToPad.clone(rewriter, loc, resultTensorTypes, newOperands);

  // Recover the subtensor out of the new static results. This keeps the
  // original linalg op around because it uses the dims of the original results.
  // This later folds away.
  SmallVector<Value> paddedSubviewResults;
  paddedSubviewResults.reserve(opToPad->getNumResults());
  SetVector<Operation *> newUsersOfOpToPad;
  for (auto it : llvm::zip(opToPad->getResults(), paddedOp->getResults())) {
    auto rank = std::get<0>(it).getType().cast<RankedTensorType>().getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    auto sizes = llvm::to_vector<4>(llvm::map_range(
        llvm::seq<unsigned>(0, rank), [&](unsigned d) -> OpFoldResult {
          auto dimOp = rewriter.create<memref::DimOp>(loc, std::get<0>(it), d);
          newUsersOfOpToPad.insert(dimOp);
          return dimOp.getResult();
        }));
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    paddedSubviewResults.push_back(rewriter.create<SubTensorOp>(
        loc, std::get<1>(it), offsets, sizes, strides));
  }
  // Replace the transient `opToPad` locally, except for uses that we just
  // created for the purpose of extracting the dims.
  rewriter.replaceOpWithIf(opToPad, paddedSubviewResults, [&](OpOperand &opOp) {
    return !newUsersOfOpToPad.contains(opOp.getOwner());
  });

  res = TiledLinalgOp{paddedOp, res.loops, res.tensorResults};
  return success();
}

/// Linalg base tiling pattern.
mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    StringRef opName, MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context), filter(filter),
      options(options) {}

mlir::linalg::LinalgBaseTilingPattern::LinalgBaseTilingPattern(
    MLIRContext *context, LinalgTilingOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBaseTilingPattern::matchAndRewriteBase(
    Operation *op, PatternRewriter &rewriter, TiledLinalgOp &result) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  Optional<TiledLinalgOp> res = tileLinalgOp(rewriter, linalgOp, options);

  if (!res)
    return failure();

  // Set up an RAII guard that returns the relevant information to the caller
  // on scope exit.
  LinalgOp tiledOp = res->op;
  auto guard = llvm::make_scope_exit([&]() {
    // Return relevant information to derived pattern.
    result = *res;
    // Replace filter on both tiledOp and tiledAndPaddedOp, if necessary.
    filter.replaceLinalgTransformationFilter(rewriter, tiledOp);
    if (tiledOp != res->op)
      filter.replaceLinalgTransformationFilter(rewriter, res->op);
  });

  // Consider padding on the fly only if the op has tensor semantics.
  if (!options.paddingValueComputationFunction ||
      !linalgOp.hasTensorSemantics())
    return success();

  // Try to pad on the fly by rewriting res->op as a padded op.
  if (failed(rewriteAsPaddedOp(rewriter, *res, options))) {
    // Padding could not be applied: fail to match. Note that the scope-exit
    // guard above still runs and records the tiled op in `result`.
    return failure();
  }

  // Do not replace `linalgOp` here; derived patterns do so as they see fit,
  // using the resulting TiledLinalgOp.
  return success();
}

static ValueRange getTiledOpResult(const TiledLinalgOp &tiledOp) {
  if (tiledOp.loops.empty())
    return tiledOp.op.getOperation()->getResults();
  return tiledOp.loops.front()->getResults();
}

static ValueRange
getTiledAndFusedOpResult(const TiledAndFusedLinalgOps &tiledAndFusedOp) {
  if (tiledAndFusedOp.fusedLoops.empty())
    return tiledAndFusedOp.op.getOperation()->getResults();
  return tiledAndFusedOp.fusedLoops.front()->getResults();
}

mlir::linalg::LinalgBaseTileAndFusePattern::LinalgBaseTileAndFusePattern(
    StringRef opName, MLIRContext *context,
    const LinalgDependenceGraph &dependenceGraph,
    LinalgTilingOptions tilingOptions, LinalgFusionOptions fusionOptions,
    LinalgTransformationFilter filter, LinalgTransformationFilter fusedOpMarker,
    LinalgTransformationFilter originalOpMarker, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}),
      dependenceGraph(dependenceGraph), tilingOptions(tilingOptions),
      fusionOptions(fusionOptions), filter(filter),
      fusedOpMarker(fusedOpMarker), originalOpMarker(originalOpMarker) {}

LogicalResult mlir::linalg::LinalgBaseTileAndFusePattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  // TODO: remove hasIndexSemantics check once index ops are supported.
  if (!linalgOp || linalgOp.hasIndexSemantics())
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();

  DenseSet<Operation *> producers;
  producers.insert(linalgOp);
  for (auto dependence : dependenceGraph.getDependentOperationsInto(linalgOp)) {
    Optional<unsigned> operandNumber = dependence.getIndexingOpViewOperandNum();
    // For dependences into `linalgOp`, the indexing op view is always an
    // OpOperand; we could assert, but simply skip the dependence otherwise.
    if (!operandNumber)
      continue;
    if (!fusionOptions.indicesToFuse.count(operandNumber.getValue()))
      continue;
    if (isa<LinalgOp>(dependence.getDependentOp()))
      producers.insert(dependence.getDependentOp());
  }

  SmallVector<LinalgOp, 1> fusionOps;
  for (auto it = op->getBlock()->begin(), ie = Block::iterator(op); it != ie;
       ++it) {
    auto producerLinalgOp = dyn_cast<LinalgOp>(&(*it));
    if (producerLinalgOp && producers.count(producerLinalgOp))
      fusionOps.push_back(producerLinalgOp);
  }
  fusionOps.push_back(linalgOp);

  SmallVector<Value, 4> tileSizes =
      tilingOptions.tileSizeComputationFunction(rewriter, op);
  LinalgTilingOptions instanceTilingOptions = tilingOptions;
  instanceTilingOptions.setTileSizes(tileSizes);
  Optional<TiledAndFusedLinalgOps> tiledAndFusedOps = tileAndFuseLinalgOps(
      rewriter, fusionOps, dependenceGraph, instanceTilingOptions);
  if (!tiledAndFusedOps)
    return failure();

  // Tile the unfused loops.
  SmallVector<Value, 4> unfusedLoopTileSizes;
  Value zero = rewriter.create<ConstantIndexOp>(op->getLoc(), 0);
  for (auto tileSize : enumerate(tileSizes)) {
    if (tiledAndFusedOps->fusedLoopDims.count(tileSize.index()))
      unfusedLoopTileSizes.push_back(zero);
    else
      unfusedLoopTileSizes.push_back(tileSize.value());
  }
  // Truncate the tile sizes to the number of loops.
  if (unfusedLoopTileSizes.size() > linalgOp.getNumLoops())
    unfusedLoopTileSizes.resize(linalgOp.getNumLoops());
  // Tile the remaining loops only if at least one tile size is non-zero.
  if (llvm::any_of(unfusedLoopTileSizes, [](Value val) {
        if (auto cst = val.getDefiningOp<ConstantIndexOp>())
          return cst.getValue() != 0;
        return true;
      })) {
    LinalgTilingOptions unfusedTilingOptions = tilingOptions;
    unfusedTilingOptions.setTileSizes(unfusedLoopTileSizes);
    Optional<TiledLinalgOp> unfusedTiledOp =
        tileLinalgOp(rewriter, tiledAndFusedOps->op, unfusedTilingOptions);
    if (!unfusedTiledOp)
      return failure();
    rewriter.replaceOp(tiledAndFusedOps->op,
                       getTiledOpResult(unfusedTiledOp.getValue()));
    tiledAndFusedOps->op = unfusedTiledOp->op;
  }
  op->replaceAllUsesWith(getTiledAndFusedOpResult(tiledAndFusedOps.getValue()));

  filter.replaceLinalgTransformationFilter(rewriter,
                                           tiledAndFusedOps->op.getOperation());
  for (auto fusedOp : tiledAndFusedOps->fusedProducers) {
    fusedOpMarker.replaceLinalgTransformationFilter(rewriter,
                                                    fusedOp.getOperation());
  }
  for (auto origProducerOp : ArrayRef<LinalgOp>(fusionOps).drop_back()) {
    originalOpMarker.replaceLinalgTransformationFilter(
        rewriter, origProducerOp.getOperation());
  }
  rewriter.updateRootInPlace(op, [&]() {
    originalOpMarker.replaceLinalgTransformationFilter(rewriter, op);
  });
  return success();
}

/// Linalg generic interchange pattern.
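/// For example (illustrative), constructing the pattern with
/// interchangeVector = {1, 0} on a two-loop linalg.generic swaps its loops,
/// transposing the iteration order.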
mlir::linalg::GenericOpInterchangePattern::GenericOpInterchangePattern(
    MLIRContext *context, ArrayRef<unsigned> interchangeVector,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : OpRewritePattern(context, benefit), filter(filter),
      interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}

LogicalResult mlir::linalg::GenericOpInterchangePattern::matchAndRewrite(
    GenericOp genericOp, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, genericOp)))
    return failure();
  if (failed(interchangeGenericOpPrecondition(genericOp, interchangeVector)))
    return failure();

  // TODO: figure out how this interacts with named ops. In particular, this
  // should break the named-op property.
  rewriter.updateRootInPlace(genericOp, [&]() {
    interchangeGenericOp(rewriter, genericOp, interchangeVector);
    // Apply the new filter if specified.
    filter.replaceLinalgTransformationFilter(rewriter, genericOp);
  });
  return success();
}

mlir::linalg::LinalgBasePromotionPattern::LinalgBasePromotionPattern(
    StringRef opName, MLIRContext *context, LinalgPromotionOptions options,
    LinalgTransformationFilter filter, PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter),
      options(options) {}

LogicalResult mlir::linalg::LinalgBasePromotionPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  if (failed(filter.checkAndNotify(rewriter, op)))
    return failure();
  if (failed(promoteSubviewsPrecondition(op, options)))
    return failure();

  // TODO: We cannot use a root update here. This pattern creates other ops,
  // so if the promotion fails, those need to be cleaned up, which doesn't
  // seem to be happening here. So to fail properly, we should be cloning the
  // op and deleting the previous op. This needs more investigation.
  rewriter.startRootUpdate(op);
  Optional<LinalgOp> promotedOp = promoteSubViews(rewriter, op, options);
  if (!promotedOp) {
    rewriter.cancelRootUpdate(op);
    return op->emitError("subview promotion failed");
  }
  rewriter.finalizeRootUpdate(op);
  filter.replaceLinalgTransformationFilter(rewriter, op);
  return success();
}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(MatchAnyOpTypeTag(), benefit, context), filter(filter) {}

mlir::linalg::LinalgBaseVectorizationPattern::LinalgBaseVectorizationPattern(
    StringRef opName, MLIRContext *context, LinalgTransformationFilter filter,
    PatternBenefit benefit)
    : RewritePattern(opName, benefit, context, {}), filter(filter) {}

LogicalResult mlir::linalg::LinalgBaseVectorizationPattern::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter) const {
  LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
  if (!linalgOp)
    return failure();
  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
    return failure();
  SmallVector<Value> newResults;
  if (failed(vectorizeLinalgOp(rewriter, op, newResults)))
    return failure();
  if (!newResults.empty())
    rewriter.replaceOp(op, newResults);
  else
    rewriter.eraseOp(op);
  return success();
}

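/// Apply each pattern set in `stage1Patterns`, interleaving it with the
/// `stage2Patterns` cleanups and the optional `stage3Lambda`. A minimal usage
/// sketch (illustrative; the variable names are hypothetical):
///   SmallVector<FrozenRewritePatternSet> stage1;
///   stage1.emplace_back(std::move(tilingPatterns));
///   FrozenRewritePatternSet stage2(std::move(canonicalizationPatterns));
///   if (failed(applyStagedPatterns(func, stage1, stage2, nullptr)))
///     return signalPassFailure();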
LogicalResult mlir::linalg::applyStagedPatterns(
    Operation *op, ArrayRef<FrozenRewritePatternSet> stage1Patterns,
    const FrozenRewritePatternSet &stage2Patterns,
    function_ref<LogicalResult(Operation *)> stage3Lambda) {
  unsigned iteration = 0;
  (void)iteration;
  for (const auto &patterns : stage1Patterns) {
    LLVM_DEBUG(DBGS() << "Before 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 1st stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 1st stage, iter: " << ++iteration << "\n"
                      << *op);
    if (failed(applyPatternsAndFoldGreedily(op, stage2Patterns))) {
      LLVM_DEBUG(DBGS() << "Underlying 2nd stage rewrite did not converge\n");
      return failure();
    }
    LLVM_DEBUG(DBGS() << "After 2nd stage, iter: " << iteration << "\n"
                      << *op);
    if (stage3Lambda) {
      if (failed(stage3Lambda(op)))
        return failure();
      LLVM_DEBUG(DBGS() << "After 3rd stage, iter: " << iteration << "\n"
                        << *op);
    }
  }
  return success();
}

/// Traverse `dims` and substitute dimensions with known min or max
/// expressions, as returned by `getMinMaxExpr`.
static AffineMap substitute(AffineMap map, SmallVectorImpl<Value> &dims,
                            SmallVectorImpl<Value> &symbols,
                            GetMinMaxExprFn getMinMaxExpr) {
  auto exprs = llvm::to_vector<4>(map.getResults());
  for (AffineExpr &expr : exprs) {
    bool substituted = true;
    while (substituted) {
      substituted = false;
      for (unsigned dimIdx = 0; dimIdx < dims.size(); ++dimIdx) {
        Value dim = dims[dimIdx];
        auto minMax = getMinMaxExpr(dim, dims, symbols);
        if (!minMax)
          continue;
        AffineExpr dimExpr = getAffineDimExpr(dimIdx, expr.getContext());
        LLVM_DEBUG(DBGS() << "Subst: " << dim << " @ " << dimExpr << "\n");
        LLVM_DEBUG(DBGS() << "Before: " << expr << "\n");
        // Substitute occurrences of `dimExpr` by either the min expression or
        // the max expression depending on whether the value is used with a
        // positive or negative coefficient.
        AffineExpr substitutedExpr =
            substWithMin(expr, dimExpr, minMax->first, minMax->second);
        LLVM_DEBUG(DBGS() << "After: " << substitutedExpr << "\n");
        substituted = (substitutedExpr != expr);
        expr = substitutedExpr;
      }
    }

    // Clean up and simplify the results.
    // This needs to happen outside of the loop iterating on dims.size() since
    // it modifies dims.
    SmallVector<Value, 4> operands(dims.begin(), dims.end());
    operands.append(symbols.begin(), symbols.end());
    auto simplifiedMap = AffineMap::get(dims.size(), symbols.size(), exprs,
                                        exprs.front().getContext());

    LLVM_DEBUG({
      DBGS() << "Map to simplify: " << simplifiedMap << "\n";
      DBGS() << "Operands:\n";
      for (Value v : operands)
        DBGS() << v << "\n";
    });

    // Pull in affine.apply operations and compose them fully into the
    // result.
    fullyComposeAffineMapAndOperands(&simplifiedMap, &operands);
    canonicalizeMapAndOperands(&simplifiedMap, &operands);
    simplifiedMap = simplifyAffineMap(simplifiedMap);
    // Assign the results.
    exprs.assign(simplifiedMap.getResults().begin(),
                 simplifiedMap.getResults().end());
    dims.assign(operands.begin(),
                operands.begin() + simplifiedMap.getNumDims());
    symbols.assign(operands.begin() + simplifiedMap.getNumDims(),
                   operands.end());

    LLVM_DEBUG(DBGS() << "Map simplified: " << simplifiedMap << "\n");
  }

  assert(!exprs.empty() && "Unexpected empty exprs");
  return AffineMap::get(dims.size(), symbols.size(), exprs, map.getContext());
}

/// Traverse the dims of the AffineMap of `affineMinOp` and substitute
/// dimensions with known range by new expressions involving the min or max
/// expression:
///   - If the AffineDimExpr mapped to a known value has a positive sign, it
///     is replaced by the min expression.
///   - If the AffineDimExpr mapped to a known value has a negative sign, it is
///     replaced by the max expression.
/// All known values are iteratively replaced.
/// This is used as an intermediate step in computing bounding boxes and
/// canonicalizing AffineMinOps. All dim and symbol operands are assumed to
/// have positive values (positive orthant assumption).
/// Return a new AffineMap, dims and symbols that have been canonicalized and
/// simplified.
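/// For example (illustrative): if `getMinMaxExpr` reports that dim d0 ranges
/// over [s0, s1], then in `d0 - d1` the positive occurrence of d0 is replaced
/// by the min s0, while in `-d0 + 8` it would be replaced by the max s1; in
/// both cases the substituted expression never exceeds the original.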
AffineMapAndOperands
mlir::linalg::substituteMin(AffineMinOp affineMinOp,
                            GetMinMaxExprFn getMinMaxExpr) {
  AffineMapAndOperands res{affineMinOp.getAffineMap(),
                           SmallVector<Value>(affineMinOp.getDimOperands()),
                           SmallVector<Value>(affineMinOp.getSymbolOperands())};
  res.map = substitute(affineMinOp.getAffineMap(), res.dims, res.symbols,
                       getMinMaxExpr);
  return res;
}

LogicalResult AffineMinRangeCanonicalizationPattern::matchAndRewrite(
    AffineMinOp minOp, PatternRewriter &rewriter) const {
  LLVM_DEBUG(DBGS() << "Canonicalize AffineMinSCF: " << *minOp.getOperation()
                    << "\n");

  auto affineMapAndOperands = substituteMin(minOp, getMinMaxFn);
  AffineMap map = affineMapAndOperands.map;

  LLVM_DEBUG(DBGS() << "Resulting map: " << map << "\n");

  // Check whether any of the expressions, when subtracted from all other
  // expressions, produces only >= 0 constants. If so, it is the min.
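  // For example (illustrative): if one result of the original map is the
  // constant 256 and the substituted map simplifies to (256, 260), the
  // sub-map is (0, 4); all results are non-negative constants, so the
  // affine.min folds to the constant 256.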
  for (auto e : minOp.getAffineMap().getResults()) {
    LLVM_DEBUG(DBGS() << "Candidate min: " << e << "\n");
    if (!e.isSymbolicOrConstant())
      continue;

    // Returns true unless `e` is a constant that is provably non-negative.
    auto mayBeNegative = [](AffineExpr e) {
      if (auto cst = e.dyn_cast<AffineConstantExpr>())
        return cst.getValue() < 0;
      return true;
    };

    // Build the subMap and check all results are statically known to be
    // non-negative.
    SmallVector<AffineExpr, 4> subExprs;
    subExprs.reserve(map.getNumResults());
    for (auto ee : map.getResults())
      subExprs.push_back(ee - e);
    MLIRContext *ctx = minOp.getContext();
    AffineMap subMap = simplifyAffineMap(
        AffineMap::get(map.getNumDims(), map.getNumSymbols(), subExprs, ctx));
    LLVM_DEBUG(DBGS() << "simplified subMap: " << subMap << "\n");
    if (llvm::any_of(subMap.getResults(), mayBeNegative))
      continue;

    // Static min found.
    if (auto cst = e.dyn_cast<AffineConstantExpr>()) {
      rewriter.replaceOpWithNewOp<ConstantIndexOp>(minOp, cst.getValue());
    } else {
      auto resultMap = AffineMap::get(0, map.getNumSymbols(), {e}, ctx);
      SmallVector<Value> resultOperands = affineMapAndOperands.dims;
      llvm::append_range(resultOperands, affineMapAndOperands.symbols);
      canonicalizeMapAndOperands(&resultMap, &resultOperands);
      resultMap = simplifyAffineMap(resultMap);
      rewriter.replaceOpWithNewOp<AffineApplyOp>(minOp, resultMap,
                                                 resultOperands);
    }
    return success();
  }

  return failure();
}