1 //===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains cross-dialect canonicalization patterns that cannot be
10 // actual canonicalization patterns due to undesired additional dependencies.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/SCF/Passes.h"
18 #include "mlir/Dialect/SCF/SCF.h"
19 #include "mlir/Dialect/SCF/Transforms.h"
20 #include "mlir/Dialect/Tensor/IR/Tensor.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23 
24 using namespace mlir;
25 using namespace mlir::scf;
26 
27 namespace {
28 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
29 ///
30 /// ```
31 /// %0 = ... : tensor<?x?xf32>
32 /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
33 ///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
34 ///   ...
35 /// }
36 /// ```
37 ///
38 /// is folded to:
39 ///
40 /// ```
41 /// %0 = ... : tensor<?x?xf32>
42 /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
43 ///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
44 ///   ...
45 /// }
46 /// ```
47 template <typename OpTy>
48 struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
49   using OpRewritePattern<OpTy>::OpRewritePattern;
50 
51   LogicalResult matchAndRewrite(OpTy dimOp,
52                                 PatternRewriter &rewriter) const override {
53     auto blockArg = dimOp.source().template dyn_cast<BlockArgument>();
54     if (!blockArg)
55       return failure();
56     auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
57     if (!forOp)
58       return failure();
59 
60     Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
61     rewriter.updateRootInPlace(
62         dimOp, [&]() { dimOp.sourceMutable().assign(initArg); });
63 
64     return success();
65   };
66 };
67 
68 /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
69 /// and scf.parallel loops with a known range.
70 template <typename OpTy, bool IsMin>
71 struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
72   using OpRewritePattern<OpTy>::OpRewritePattern;
73 
74   LogicalResult matchAndRewrite(OpTy op,
75                                 PatternRewriter &rewriter) const override {
76     auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) {
77       if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
78         lb = forOp.lowerBound();
79         ub = forOp.upperBound();
80         step = forOp.step();
81         return success();
82       }
83       if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
84         for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
85           if (parOp.getInductionVars()[idx] == iv) {
86             lb = parOp.lowerBound()[idx];
87             ub = parOp.upperBound()[idx];
88             step = parOp.step()[idx];
89             return success();
90           }
91         }
92         return failure();
93       }
94       return failure();
95     };
96 
97     return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(),
98                                            op.operands(), IsMin, loopMatcher);
99   }
100 };
101 
102 struct SCFForLoopCanonicalization
103     : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
104   void runOnFunction() override {
105     FuncOp funcOp = getFunction();
106     MLIRContext *ctx = funcOp.getContext();
107     RewritePatternSet patterns(ctx);
108     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
109     if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
110       signalPassFailure();
111   }
112 };
113 } // namespace
114 
115 void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
116     RewritePatternSet &patterns) {
117   MLIRContext *ctx = patterns.getContext();
118   patterns
119       .insert<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
120               AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
121               DimOfIterArgFolder<tensor::DimOp>,
122               DimOfIterArgFolder<memref::DimOp>>(ctx);
123 }
124 
125 std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
126   return std::make_unique<SCFForLoopCanonicalization>();
127 }
128