1 //===- LoopRangeFolding.cpp - Code to perform loop range folding-----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements loop range folding.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
15 #include "mlir/Dialect/SCF/IR/SCF.h"
16 #include "mlir/Dialect/SCF/Transforms/Passes.h"
17 #include "mlir/Dialect/SCF/Transforms/Transforms.h"
18 #include "mlir/Dialect/SCF/Utils/Utils.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20
21 using namespace mlir;
22 using namespace mlir::scf;
23
24 namespace {
25 struct ForLoopRangeFolding
26 : public SCFForLoopRangeFoldingBase<ForLoopRangeFolding> {
27 void runOnOperation() override;
28 };
29 } // namespace
30
runOnOperation()31 void ForLoopRangeFolding::runOnOperation() {
32 getOperation()->walk([&](ForOp op) {
33 Value indVar = op.getInductionVar();
34
35 auto canBeFolded = [&](Value value) {
36 return op.isDefinedOutsideOfLoop(value) || value == indVar;
37 };
38
39 // Fold until a fixed point is reached
40 while (true) {
41
42 // If the induction variable is used more than once, we can't fold its
43 // arith ops into the loop range
44 if (!indVar.hasOneUse())
45 break;
46
47 Operation *user = *indVar.getUsers().begin();
48 if (!isa<arith::AddIOp, arith::MulIOp>(user))
49 break;
50
51 if (!llvm::all_of(user->getOperands(), canBeFolded))
52 break;
53
54 OpBuilder b(op);
55 BlockAndValueMapping lbMap;
56 lbMap.map(indVar, op.getLowerBound());
57 BlockAndValueMapping ubMap;
58 ubMap.map(indVar, op.getUpperBound());
59 BlockAndValueMapping stepMap;
60 stepMap.map(indVar, op.getStep());
61
62 if (isa<arith::AddIOp>(user)) {
63 Operation *lbFold = b.clone(*user, lbMap);
64 Operation *ubFold = b.clone(*user, ubMap);
65
66 op.setLowerBound(lbFold->getResult(0));
67 op.setUpperBound(ubFold->getResult(0));
68
69 } else if (isa<arith::MulIOp>(user)) {
70 Operation *ubFold = b.clone(*user, ubMap);
71 Operation *stepFold = b.clone(*user, stepMap);
72
73 op.setUpperBound(ubFold->getResult(0));
74 op.setStep(stepFold->getResult(0));
75 }
76
77 ValueRange wrapIndvar(indVar);
78 user->replaceAllUsesWith(wrapIndvar);
79 user->erase();
80 }
81 });
82 }
83
createForLoopRangeFoldingPass()84 std::unique_ptr<Pass> mlir::createForLoopRangeFoldingPass() {
85 return std::make_unique<ForLoopRangeFolding>();
86 }
87