//===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains cross-dialect canonicalization patterns. They cannot be
// regular canonicalization patterns on the ops themselves because that would
// introduce undesired additional dialect dependencies.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/Passes.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SCF/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::scf;

/// A simple, conservative analysis to determine if the loop is shape
/// preserving. I.e., the type of the `arg`-th yielded value is the same as the
/// type of the corresponding basic block argument of the loop.
/// Note: This function handles only simple cases. Expand as needed.
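///
/// For example, the following loop is shape preserving w.r.t. its single
/// iter_arg, because the yielded value is a tensor.insert_slice into the
/// iter_arg, which leaves the type of the destination unchanged (IR sketch,
/// offsets/sizes/strides elided):
///
/// ```
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.insert_slice %v into %arg0[...] [...] [...]
///       : tensor<1x4xf32> into tensor<?x?xf32>
///   scf.yield %1 : tensor<?x?xf32>
/// }
/// ```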
static bool isShapePreserving(ForOp forOp, int64_t arg) {
  auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
  assert(arg < static_cast<int64_t>(yieldOp.results().size()) &&
         "arg is out of bounds");
  Value value = yieldOp.results()[arg];
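  // Walk the use-def chain backwards from the yielded value through ops that
  // provably preserve the destination's shape, until the corresponding
  // iter_arg is reached (shape preserving) or the chain breaks (unknown).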
  while (value) {
    if (value == forOp.getRegionIterArgs()[arg])
      return true;
    OpResult opResult = value.dyn_cast<OpResult>();
    if (!opResult)
      return false;

    using tensor::InsertSliceOp;
    value =
        llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
            .Case<InsertSliceOp>([&](InsertSliceOp op) { return op.dest(); })
            .Case<ForOp>([&](ForOp forOp) {
              return isShapePreserving(forOp, opResult.getResultNumber())
                         ? forOp.getIterOperands()[opResult.getResultNumber()]
                         : Value();
            })
            .Default([&](auto op) { return Value(); });
  }
  return false;
}

namespace {
/// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
///   ...
/// }
/// ```
///
/// is folded to:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
///   ...
/// }
/// ```
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the iter arg does not change with loop iterations.
template <typename OpTy>
struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    auto blockArg = dimOp.source().template dyn_cast<BlockArgument>();
    if (!blockArg)
      return failure();
    auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
    if (!forOp)
      return failure();
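    // Note: Block argument 0 is the loop's induction variable, so the index
    // into the iter_args is the argument number minus one.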
    if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
      return failure();

    Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
    rewriter.updateRootInPlace(
        dimOp, [&]() { dimOp.sourceMutable().assign(initArg); });

    return success();
  }
};

/// Fold dim ops of loop results to dim ops of their respective init args.
/// E.g.:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   ...
/// }
/// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
/// ```
///
/// is folded to:
///
/// ```
/// %0 = ... : tensor<?x?xf32>
/// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
///   ...
/// }
/// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
/// ```
///
/// Note: Dim ops are folded only if it can be proven that the runtime type of
/// the iter arg does not change with loop iterations.
template <typename OpTy>
struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy dimOp,
                                PatternRewriter &rewriter) const override {
    auto forOp = dimOp.source().template getDefiningOp<scf::ForOp>();
    if (!forOp)
      return failure();
    auto opResult = dimOp.source().template cast<OpResult>();
    unsigned resultNumber = opResult.getResultNumber();
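    // An scf.for result is at the same index as its corresponding iter_arg
    // and yielded value, so the result number can be used directly.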
    if (!isShapePreserving(forOp, resultNumber))
      return failure();
    rewriter.updateRootInPlace(dimOp, [&]() {
      dimOp.sourceMutable().assign(forOp.getIterOperands()[resultNumber]);
    });
    return success();
  }
};

/// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
/// and scf.parallel loops with a known range.
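///
/// For example, in the loop below the `affine.min` always evaluates to 8:
/// the step evenly divides the loop range, so `128 - %iv >= 8` on every
/// iteration, and the op can be replaced with a constant (IR sketch):
///
/// ```
/// scf.for %iv = %c0 to %c128 step %c8 {
///   %0 = affine.min affine_map<(d0) -> (8, -d0 + 128)>(%iv)
///   ...
/// }
/// ```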
template <typename OpTy, bool IsMin>
struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
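    // Match an induction variable back to the bounds and step of its
    // enclosing scf.for or scf.parallel loop; fail for any other value.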
    auto loopMatcher = [](Value iv, Value &lb, Value &ub, Value &step) {
      if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
        lb = forOp.lowerBound();
        ub = forOp.upperBound();
        step = forOp.step();
        return success();
      }
      if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
        for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
          if (parOp.getInductionVars()[idx] == iv) {
            lb = parOp.lowerBound()[idx];
            ub = parOp.upperBound()[idx];
            step = parOp.step()[idx];
            return success();
          }
        }
        return failure();
      }
      return failure();
    };

    return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(),
                                           op.operands(), IsMin, loopMatcher);
  }
};

struct SCFForLoopCanonicalization
    : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
  void runOnFunction() override {
    FuncOp funcOp = getFunction();
    MLIRContext *ctx = funcOp.getContext();
    RewritePatternSet patterns(ctx);
    scf::populateSCFForLoopCanonicalizationPatterns(patterns);
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns
      .insert<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
              AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
              DimOfIterArgFolder<tensor::DimOp>,
              DimOfIterArgFolder<memref::DimOp>,
              DimOfLoopResultFolder<tensor::DimOp>,
              DimOfLoopResultFolder<memref::DimOp>>(ctx);
}
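
// Example (a sketch, for downstream users): instead of running the standalone
// pass, the patterns above can be mixed with other patterns and applied with
// the greedy driver, e.g.:
//
//   RewritePatternSet patterns(ctx);
//   scf::populateSCFForLoopCanonicalizationPatterns(patterns);
//   // ... add further patterns ...
//   if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
//     return failure();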

std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
  return std::make_unique<SCFForLoopCanonicalization>();
}