1 //===- LoopCanonicalization.cpp - Cross-dialect canonicalization patterns -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file contains cross-dialect canonicalization patterns that cannot be
// registered as regular (op-attached) canonicalization patterns, because doing
// so would introduce undesired dependencies between dialects.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 #include "mlir/Dialect/MemRef/IR/MemRef.h"
17 #include "mlir/Dialect/SCF/IR/SCF.h"
18 #include "mlir/Dialect/SCF/Transforms/Passes.h"
19 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
20 #include "mlir/Dialect/SCF/Utils/AffineCanonicalizationUtils.h"
21 #include "mlir/Dialect/Tensor/IR/Tensor.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 using namespace mlir;
27 using namespace mlir::scf;
28 
29 /// A simple, conservative analysis to determine if the loop is shape
30 /// conserving. I.e., the type of the arg-th yielded value is the same as the
31 /// type of the corresponding basic block argument of the loop.
32 /// Note: This function handles only simple cases. Expand as needed.
isShapePreserving(ForOp forOp,int64_t arg)33 static bool isShapePreserving(ForOp forOp, int64_t arg) {
34   auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
35   assert(arg < static_cast<int64_t>(yieldOp.getResults().size()) &&
36          "arg is out of bounds");
37   Value value = yieldOp.getResults()[arg];
38   while (value) {
39     if (value == forOp.getRegionIterArgs()[arg])
40       return true;
41     OpResult opResult = value.dyn_cast<OpResult>();
42     if (!opResult)
43       return false;
44 
45     using tensor::InsertSliceOp;
46     value =
47         llvm::TypeSwitch<Operation *, Value>(opResult.getOwner())
48             .template Case<InsertSliceOp>(
49                 [&](InsertSliceOp op) { return op.getDest(); })
50             .template Case<ForOp>([&](ForOp forOp) {
51               return isShapePreserving(forOp, opResult.getResultNumber())
52                          ? forOp.getIterOperands()[opResult.getResultNumber()]
53                          : Value();
54             })
55             .Default([&](auto op) { return Value(); });
56   }
57   return false;
58 }
59 
60 namespace {
61 /// Fold dim ops of iter_args to dim ops of their respective init args. E.g.:
62 ///
63 /// ```
64 /// %0 = ... : tensor<?x?xf32>
65 /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
66 ///   %1 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
67 ///   ...
68 /// }
69 /// ```
70 ///
71 /// is folded to:
72 ///
73 /// ```
74 /// %0 = ... : tensor<?x?xf32>
75 /// scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
76 ///   %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
77 ///   ...
78 /// }
79 /// ```
80 ///
81 /// Note: Dim ops are folded only if it can be proven that the runtime type of
82 /// the iter arg does not change with loop iterations.
83 template <typename OpTy>
84 struct DimOfIterArgFolder : public OpRewritePattern<OpTy> {
85   using OpRewritePattern<OpTy>::OpRewritePattern;
86 
matchAndRewrite__anon1ba6c2270411::DimOfIterArgFolder87   LogicalResult matchAndRewrite(OpTy dimOp,
88                                 PatternRewriter &rewriter) const override {
89     auto blockArg = dimOp.getSource().template dyn_cast<BlockArgument>();
90     if (!blockArg)
91       return failure();
92     auto forOp = dyn_cast<ForOp>(blockArg.getParentBlock()->getParentOp());
93     if (!forOp)
94       return failure();
95     if (!isShapePreserving(forOp, blockArg.getArgNumber() - 1))
96       return failure();
97 
98     Value initArg = forOp.getOpOperandForRegionIterArg(blockArg).get();
99     rewriter.updateRootInPlace(
100         dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
101 
102     return success();
103   };
104 };
105 
106 /// Fold dim ops of loop results to dim ops of their respective init args. E.g.:
107 ///
108 /// ```
109 /// %0 = ... : tensor<?x?xf32>
110 /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
111 ///   ...
112 /// }
113 /// %1 = tensor.dim %r, %c0 : tensor<?x?xf32>
114 /// ```
115 ///
116 /// is folded to:
117 ///
118 /// ```
119 /// %0 = ... : tensor<?x?xf32>
120 /// %r = scf.for ... iter_args(%arg0 = %0) -> (tensor<?x?xf32>) {
121 ///   ...
122 /// }
123 /// %1 = tensor.dim %0, %c0 : tensor<?x?xf32>
124 /// ```
125 ///
126 /// Note: Dim ops are folded only if it can be proven that the runtime type of
127 /// the iter arg does not change with loop iterations.
128 template <typename OpTy>
129 struct DimOfLoopResultFolder : public OpRewritePattern<OpTy> {
130   using OpRewritePattern<OpTy>::OpRewritePattern;
131 
matchAndRewrite__anon1ba6c2270411::DimOfLoopResultFolder132   LogicalResult matchAndRewrite(OpTy dimOp,
133                                 PatternRewriter &rewriter) const override {
134     auto forOp = dimOp.getSource().template getDefiningOp<scf::ForOp>();
135     if (!forOp)
136       return failure();
137     auto opResult = dimOp.getSource().template cast<OpResult>();
138     unsigned resultNumber = opResult.getResultNumber();
139     if (!isShapePreserving(forOp, resultNumber))
140       return failure();
141     rewriter.updateRootInPlace(dimOp, [&]() {
142       dimOp.getSourceMutable().assign(forOp.getIterOperands()[resultNumber]);
143     });
144     return success();
145   }
146 };
147 
148 /// Canonicalize AffineMinOp/AffineMaxOp operations in the context of scf.for
149 /// and scf.parallel loops with a known range.
150 template <typename OpTy, bool IsMin>
151 struct AffineOpSCFCanonicalizationPattern : public OpRewritePattern<OpTy> {
152   using OpRewritePattern<OpTy>::OpRewritePattern;
153 
matchAndRewrite__anon1ba6c2270411::AffineOpSCFCanonicalizationPattern154   LogicalResult matchAndRewrite(OpTy op,
155                                 PatternRewriter &rewriter) const override {
156     auto loopMatcher = [](Value iv, OpFoldResult &lb, OpFoldResult &ub,
157                           OpFoldResult &step) {
158       if (scf::ForOp forOp = scf::getForInductionVarOwner(iv)) {
159         lb = forOp.getLowerBound();
160         ub = forOp.getUpperBound();
161         step = forOp.getStep();
162         return success();
163       }
164       if (scf::ParallelOp parOp = scf::getParallelForInductionVarOwner(iv)) {
165         for (unsigned idx = 0; idx < parOp.getNumLoops(); ++idx) {
166           if (parOp.getInductionVars()[idx] == iv) {
167             lb = parOp.getLowerBound()[idx];
168             ub = parOp.getUpperBound()[idx];
169             step = parOp.getStep()[idx];
170             return success();
171           }
172         }
173         return failure();
174       }
175       if (scf::ForeachThreadOp foreachThreadOp =
176               scf::getForeachThreadOpThreadIndexOwner(iv)) {
177         for (int64_t idx = 0; idx < foreachThreadOp.getRank(); ++idx) {
178           if (foreachThreadOp.getThreadIndices()[idx] == iv) {
179             lb = OpBuilder(iv.getContext()).getIndexAttr(0);
180             ub = foreachThreadOp.getNumThreads()[idx];
181             step = OpBuilder(iv.getContext()).getIndexAttr(1);
182             return success();
183           }
184         }
185         return failure();
186       }
187       return failure();
188     };
189 
190     return scf::canonicalizeMinMaxOpInLoop(rewriter, op, op.getAffineMap(),
191                                            op.operands(), IsMin, loopMatcher);
192   }
193 };
194 
195 struct SCFForLoopCanonicalization
196     : public SCFForLoopCanonicalizationBase<SCFForLoopCanonicalization> {
runOnOperation__anon1ba6c2270411::SCFForLoopCanonicalization197   void runOnOperation() override {
198     auto *parentOp = getOperation();
199     MLIRContext *ctx = parentOp->getContext();
200     RewritePatternSet patterns(ctx);
201     scf::populateSCFForLoopCanonicalizationPatterns(patterns);
202     if (failed(applyPatternsAndFoldGreedily(parentOp, std::move(patterns))))
203       signalPassFailure();
204   }
205 };
206 } // namespace
207 
populateSCFForLoopCanonicalizationPatterns(RewritePatternSet & patterns)208 void mlir::scf::populateSCFForLoopCanonicalizationPatterns(
209     RewritePatternSet &patterns) {
210   MLIRContext *ctx = patterns.getContext();
211   patterns
212       .add<AffineOpSCFCanonicalizationPattern<AffineMinOp, /*IsMin=*/true>,
213            AffineOpSCFCanonicalizationPattern<AffineMaxOp, /*IsMin=*/false>,
214            DimOfIterArgFolder<tensor::DimOp>, DimOfIterArgFolder<memref::DimOp>,
215            DimOfLoopResultFolder<tensor::DimOp>,
216            DimOfLoopResultFolder<memref::DimOp>>(ctx);
217 }
218 
createSCFForLoopCanonicalizationPass()219 std::unique_ptr<Pass> mlir::createSCFForLoopCanonicalizationPass() {
220   return std::make_unique<SCFForLoopCanonicalization>();
221 }
222