1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass performs a sparse conditional constant propagation
10 // in MLIR. It identifies values known to be constant, propagates that
11 // information throughout the IR, and replaces them. This is done with an
12 // optimistic dataflow analysis that assumes that all values are constant until
13 // proven otherwise.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "PassDetail.h"
18 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
19 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/Interfaces/SideEffectInterfaces.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/FoldUtils.h"
25 #include "mlir/Transforms/Passes.h"
26 
27 using namespace mlir;
28 using namespace mlir::dataflow;
29 
30 //===----------------------------------------------------------------------===//
31 // SCCP Rewrites
32 //===----------------------------------------------------------------------===//
33 
34 /// Replace the given value with a constant if the corresponding lattice
35 /// represents a constant. Returns success if the value was replaced, failure
36 /// otherwise.
replaceWithConstant(DataFlowSolver & solver,OpBuilder & builder,OperationFolder & folder,Value value)37 static LogicalResult replaceWithConstant(DataFlowSolver &solver,
38                                          OpBuilder &builder,
39                                          OperationFolder &folder, Value value) {
40   auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
41   if (!lattice || lattice->isUninitialized())
42     return failure();
43   const ConstantValue &latticeValue = lattice->getValue();
44   if (!latticeValue.getConstantValue())
45     return failure();
46 
47   // Attempt to materialize a constant for the given value.
48   Dialect *dialect = latticeValue.getConstantDialect();
49   Value constant = folder.getOrCreateConstant(builder, dialect,
50                                               latticeValue.getConstantValue(),
51                                               value.getType(), value.getLoc());
52   if (!constant)
53     return failure();
54 
55   value.replaceAllUsesWith(constant);
56   return success();
57 }
58 
59 /// Rewrite the given regions using the computing analysis. This replaces the
60 /// uses of all values that have been computed to be constant, and erases as
61 /// many newly dead operations.
rewrite(DataFlowSolver & solver,MLIRContext * context,MutableArrayRef<Region> initialRegions)62 static void rewrite(DataFlowSolver &solver, MLIRContext *context,
63                     MutableArrayRef<Region> initialRegions) {
64   SmallVector<Block *> worklist;
65   auto addToWorklist = [&](MutableArrayRef<Region> regions) {
66     for (Region &region : regions)
67       for (Block &block : llvm::reverse(region))
68         worklist.push_back(&block);
69   };
70 
71   // An operation folder used to create and unique constants.
72   OperationFolder folder(context);
73   OpBuilder builder(context);
74 
75   addToWorklist(initialRegions);
76   while (!worklist.empty()) {
77     Block *block = worklist.pop_back_val();
78 
79     for (Operation &op : llvm::make_early_inc_range(*block)) {
80       builder.setInsertionPoint(&op);
81 
82       // Replace any result with constants.
83       bool replacedAll = op.getNumResults() != 0;
84       for (Value res : op.getResults())
85         replacedAll &=
86             succeeded(replaceWithConstant(solver, builder, folder, res));
87 
88       // If all of the results of the operation were replaced, try to erase
89       // the operation completely.
90       if (replacedAll && wouldOpBeTriviallyDead(&op)) {
91         assert(op.use_empty() && "expected all uses to be replaced");
92         op.erase();
93         continue;
94       }
95 
96       // Add any the regions of this operation to the worklist.
97       addToWorklist(op.getRegions());
98     }
99 
100     // Replace any block arguments with constants.
101     builder.setInsertionPointToStart(block);
102     for (BlockArgument arg : block->getArguments())
103       (void)replaceWithConstant(solver, builder, folder, arg);
104   }
105 }
106 
107 //===----------------------------------------------------------------------===//
108 // SCCP Pass
109 //===----------------------------------------------------------------------===//
110 
111 namespace {
112 struct SCCP : public SCCPBase<SCCP> {
113   void runOnOperation() override;
114 };
115 } // namespace
116 
runOnOperation()117 void SCCP::runOnOperation() {
118   Operation *op = getOperation();
119 
120   DataFlowSolver solver;
121   solver.load<DeadCodeAnalysis>();
122   solver.load<SparseConstantPropagation>();
123   if (failed(solver.initializeAndRun(op)))
124     return signalPassFailure();
125   rewrite(solver, op->getContext(), op->getRegions());
126 }
127 
createSCCPPass()128 std::unique_ptr<Pass> mlir::createSCCPPass() {
129   return std::make_unique<SCCP>();
130 }
131