1 //===- TestConstantFold.cpp - Pass to test constant folding ---------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/Dialect/AffineOps/AffineOps.h"
19 #include "mlir/Dialect/StandardOps/Ops.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Function.h"
22 #include "mlir/Pass/Pass.h"
23 #include "mlir/Transforms/FoldUtils.h"
24 #include "mlir/Transforms/Passes.h"
25 #include "mlir/Transforms/Utils.h"
26 
27 using namespace mlir;
28 
29 namespace {
30 /// Simple constant folding pass.
31 struct TestConstantFold : public FunctionPass<TestConstantFold> {
32   // All constants in the function post folding.
33   SmallVector<Operation *, 8> existingConstants;
34 
35   void foldOperation(Operation *op, OperationFolder &helper);
36   void runOnFunction() override;
37 };
38 } // end anonymous namespace
39 
40 void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) {
41   auto processGeneratedConstants = [this](Operation *op) {
42     existingConstants.push_back(op);
43   };
44 
45   // Attempt to fold the specified operation, including handling unused or
46   // duplicated constants.
47   (void)helper.tryToFold(op, processGeneratedConstants);
48 }
49 
50 // For now, we do a simple top-down pass over a function folding constants.  We
51 // don't handle conditional control flow, block arguments, folding conditional
52 // branches, or anything else fancy.
53 void TestConstantFold::runOnFunction() {
54   existingConstants.clear();
55 
56   // Collect and fold the operations within the function.
57   SmallVector<Operation *, 8> ops;
58   getFunction().walk([&](Operation *op) { ops.push_back(op); });
59 
60   // Fold the constants in reverse so that the last generated constants from
61   // folding are at the beginning. This creates somewhat of a linear ordering to
62   // the newly generated constants that matches the operation order and improves
63   // the readability of test cases.
64   OperationFolder helper;
65   for (Operation *op : llvm::reverse(ops))
66     foldOperation(op, helper);
67 
68   // By the time we are done, we may have simplified a bunch of code, leaving
69   // around dead constants.  Check for them now and remove them.
70   for (auto *cst : existingConstants) {
71     if (cst->use_empty())
72       cst->erase();
73   }
74 }
75 
76 /// Creates a constant folding pass.
77 std::unique_ptr<FunctionPassBase> mlir::createTestConstantFoldPass() {
78   return std::make_unique<TestConstantFold>();
79 }
80 
81 static PassRegistration<TestConstantFold>
82     pass("test-constant-fold", "Test operation constant folding");
83