1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This transformation pass performs a sparse conditional constant propagation
10 // in MLIR. It identifies values known to be constant, propagates that
11 // information throughout the IR, and replaces them. This is done with an
12 // optimistic dataflow analysis that assumes that all values are constant until
13 // proven otherwise.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "PassDetail.h"
18 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
19 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Dialect.h"
22 #include "mlir/Interfaces/SideEffectInterfaces.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Transforms/FoldUtils.h"
25 #include "mlir/Transforms/Passes.h"
26
27 using namespace mlir;
28 using namespace mlir::dataflow;
29
30 //===----------------------------------------------------------------------===//
31 // SCCP Rewrites
32 //===----------------------------------------------------------------------===//
33
34 /// Replace the given value with a constant if the corresponding lattice
35 /// represents a constant. Returns success if the value was replaced, failure
36 /// otherwise.
replaceWithConstant(DataFlowSolver & solver,OpBuilder & builder,OperationFolder & folder,Value value)37 static LogicalResult replaceWithConstant(DataFlowSolver &solver,
38 OpBuilder &builder,
39 OperationFolder &folder, Value value) {
40 auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
41 if (!lattice || lattice->isUninitialized())
42 return failure();
43 const ConstantValue &latticeValue = lattice->getValue();
44 if (!latticeValue.getConstantValue())
45 return failure();
46
47 // Attempt to materialize a constant for the given value.
48 Dialect *dialect = latticeValue.getConstantDialect();
49 Value constant = folder.getOrCreateConstant(builder, dialect,
50 latticeValue.getConstantValue(),
51 value.getType(), value.getLoc());
52 if (!constant)
53 return failure();
54
55 value.replaceAllUsesWith(constant);
56 return success();
57 }
58
59 /// Rewrite the given regions using the computing analysis. This replaces the
60 /// uses of all values that have been computed to be constant, and erases as
61 /// many newly dead operations.
rewrite(DataFlowSolver & solver,MLIRContext * context,MutableArrayRef<Region> initialRegions)62 static void rewrite(DataFlowSolver &solver, MLIRContext *context,
63 MutableArrayRef<Region> initialRegions) {
64 SmallVector<Block *> worklist;
65 auto addToWorklist = [&](MutableArrayRef<Region> regions) {
66 for (Region ®ion : regions)
67 for (Block &block : llvm::reverse(region))
68 worklist.push_back(&block);
69 };
70
71 // An operation folder used to create and unique constants.
72 OperationFolder folder(context);
73 OpBuilder builder(context);
74
75 addToWorklist(initialRegions);
76 while (!worklist.empty()) {
77 Block *block = worklist.pop_back_val();
78
79 for (Operation &op : llvm::make_early_inc_range(*block)) {
80 builder.setInsertionPoint(&op);
81
82 // Replace any result with constants.
83 bool replacedAll = op.getNumResults() != 0;
84 for (Value res : op.getResults())
85 replacedAll &=
86 succeeded(replaceWithConstant(solver, builder, folder, res));
87
88 // If all of the results of the operation were replaced, try to erase
89 // the operation completely.
90 if (replacedAll && wouldOpBeTriviallyDead(&op)) {
91 assert(op.use_empty() && "expected all uses to be replaced");
92 op.erase();
93 continue;
94 }
95
96 // Add any the regions of this operation to the worklist.
97 addToWorklist(op.getRegions());
98 }
99
100 // Replace any block arguments with constants.
101 builder.setInsertionPointToStart(block);
102 for (BlockArgument arg : block->getArguments())
103 (void)replaceWithConstant(solver, builder, folder, arg);
104 }
105 }
106
107 //===----------------------------------------------------------------------===//
108 // SCCP Pass
109 //===----------------------------------------------------------------------===//
110
111 namespace {
112 struct SCCP : public SCCPBase<SCCP> {
113 void runOnOperation() override;
114 };
115 } // namespace
116
runOnOperation()117 void SCCP::runOnOperation() {
118 Operation *op = getOperation();
119
120 DataFlowSolver solver;
121 solver.load<DeadCodeAnalysis>();
122 solver.load<SparseConstantPropagation>();
123 if (failed(solver.initializeAndRun(op)))
124 return signalPassFailure();
125 rewrite(solver, op->getContext(), op->getRegions());
126 }
127
createSCCPPass()128 std::unique_ptr<Pass> mlir::createSCCPPass() {
129 return std::make_unique<SCCP>();
130 }
131