1152d29ccSRiver Riddle //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
2152d29ccSRiver Riddle //
3152d29ccSRiver Riddle // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4152d29ccSRiver Riddle // See https://llvm.org/LICENSE.txt for license information.
5152d29ccSRiver Riddle // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6152d29ccSRiver Riddle //
7152d29ccSRiver Riddle //===----------------------------------------------------------------------===//
8152d29ccSRiver Riddle //
9152d29ccSRiver Riddle // This transformation pass performs a sparse conditional constant propagation
10152d29ccSRiver Riddle // in MLIR. It identifies values known to be constant, propagates that
11152d29ccSRiver Riddle // information throughout the IR, and replaces them. This is done with an
1241b09f4eSKazuaki Ishizaki // optimistic dataflow analysis that assumes that all values are constant until
13152d29ccSRiver Riddle // proven otherwise.
14152d29ccSRiver Riddle //
15152d29ccSRiver Riddle //===----------------------------------------------------------------------===//
16152d29ccSRiver Riddle
17152d29ccSRiver Riddle #include "PassDetail.h"
189432fbfeSMogball #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
199432fbfeSMogball #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
20152d29ccSRiver Riddle #include "mlir/IR/Builders.h"
21152d29ccSRiver Riddle #include "mlir/IR/Dialect.h"
22eb623ae8SStephen Neuendorffer #include "mlir/Interfaces/SideEffectInterfaces.h"
23152d29ccSRiver Riddle #include "mlir/Pass/Pass.h"
24152d29ccSRiver Riddle #include "mlir/Transforms/FoldUtils.h"
25152d29ccSRiver Riddle #include "mlir/Transforms/Passes.h"
26152d29ccSRiver Riddle
27152d29ccSRiver Riddle using namespace mlir;
289432fbfeSMogball using namespace mlir::dataflow;
29a90151d6SRiver Riddle
30d07c90e3SRiver Riddle //===----------------------------------------------------------------------===//
31d07c90e3SRiver Riddle // SCCP Rewrites
32d07c90e3SRiver Riddle //===----------------------------------------------------------------------===//
33a90151d6SRiver Riddle
34152d29ccSRiver Riddle /// Replace the given value with a constant if the corresponding lattice
35152d29ccSRiver Riddle /// represents a constant. Returns success if the value was replaced, failure
36152d29ccSRiver Riddle /// otherwise.
replaceWithConstant(DataFlowSolver & solver,OpBuilder & builder,OperationFolder & folder,Value value)379432fbfeSMogball static LogicalResult replaceWithConstant(DataFlowSolver &solver,
38d07c90e3SRiver Riddle OpBuilder &builder,
39d07c90e3SRiver Riddle OperationFolder &folder, Value value) {
409432fbfeSMogball auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
41*ab701975SMogball if (!lattice || lattice->isUninitialized())
42d07c90e3SRiver Riddle return failure();
439432fbfeSMogball const ConstantValue &latticeValue = lattice->getValue();
449432fbfeSMogball if (!latticeValue.getConstantValue())
45d07c90e3SRiver Riddle return failure();
46152d29ccSRiver Riddle
47d07c90e3SRiver Riddle // Attempt to materialize a constant for the given value.
489432fbfeSMogball Dialect *dialect = latticeValue.getConstantDialect();
499432fbfeSMogball Value constant = folder.getOrCreateConstant(builder, dialect,
509432fbfeSMogball latticeValue.getConstantValue(),
519432fbfeSMogball value.getType(), value.getLoc());
52d07c90e3SRiver Riddle if (!constant)
53d07c90e3SRiver Riddle return failure();
54d07c90e3SRiver Riddle
55d07c90e3SRiver Riddle value.replaceAllUsesWith(constant);
56d07c90e3SRiver Riddle return success();
572eda87dfSRiver Riddle }
582eda87dfSRiver Riddle
59d07c90e3SRiver Riddle /// Rewrite the given regions using the computing analysis. This replaces the
60d07c90e3SRiver Riddle /// uses of all values that have been computed to be constant, and erases as
61d07c90e3SRiver Riddle /// many newly dead operations.
rewrite(DataFlowSolver & solver,MLIRContext * context,MutableArrayRef<Region> initialRegions)629432fbfeSMogball static void rewrite(DataFlowSolver &solver, MLIRContext *context,
63152d29ccSRiver Riddle MutableArrayRef<Region> initialRegions) {
64d07c90e3SRiver Riddle SmallVector<Block *> worklist;
65152d29ccSRiver Riddle auto addToWorklist = [&](MutableArrayRef<Region> regions) {
66152d29ccSRiver Riddle for (Region ®ion : regions)
67d07c90e3SRiver Riddle for (Block &block : llvm::reverse(region))
68152d29ccSRiver Riddle worklist.push_back(&block);
69152d29ccSRiver Riddle };
70152d29ccSRiver Riddle
71152d29ccSRiver Riddle // An operation folder used to create and unique constants.
72152d29ccSRiver Riddle OperationFolder folder(context);
73152d29ccSRiver Riddle OpBuilder builder(context);
74152d29ccSRiver Riddle
75152d29ccSRiver Riddle addToWorklist(initialRegions);
76152d29ccSRiver Riddle while (!worklist.empty()) {
77152d29ccSRiver Riddle Block *block = worklist.pop_back_val();
78152d29ccSRiver Riddle
79152d29ccSRiver Riddle for (Operation &op : llvm::make_early_inc_range(*block)) {
80152d29ccSRiver Riddle builder.setInsertionPoint(&op);
81152d29ccSRiver Riddle
82152d29ccSRiver Riddle // Replace any result with constants.
83152d29ccSRiver Riddle bool replacedAll = op.getNumResults() != 0;
84152d29ccSRiver Riddle for (Value res : op.getResults())
85d07c90e3SRiver Riddle replacedAll &=
869432fbfeSMogball succeeded(replaceWithConstant(solver, builder, folder, res));
87152d29ccSRiver Riddle
88152d29ccSRiver Riddle // If all of the results of the operation were replaced, try to erase
89152d29ccSRiver Riddle // the operation completely.
90152d29ccSRiver Riddle if (replacedAll && wouldOpBeTriviallyDead(&op)) {
91152d29ccSRiver Riddle assert(op.use_empty() && "expected all uses to be replaced");
92152d29ccSRiver Riddle op.erase();
93152d29ccSRiver Riddle continue;
94152d29ccSRiver Riddle }
95152d29ccSRiver Riddle
96152d29ccSRiver Riddle // Add any the regions of this operation to the worklist.
97152d29ccSRiver Riddle addToWorklist(op.getRegions());
98152d29ccSRiver Riddle }
99d07c90e3SRiver Riddle
100d07c90e3SRiver Riddle // Replace any block arguments with constants.
101d07c90e3SRiver Riddle builder.setInsertionPointToStart(block);
102d07c90e3SRiver Riddle for (BlockArgument arg : block->getArguments())
1039432fbfeSMogball (void)replaceWithConstant(solver, builder, folder, arg);
104152d29ccSRiver Riddle }
105152d29ccSRiver Riddle }
106152d29ccSRiver Riddle
107152d29ccSRiver Riddle //===----------------------------------------------------------------------===//
108152d29ccSRiver Riddle // SCCP Pass
109152d29ccSRiver Riddle //===----------------------------------------------------------------------===//
110152d29ccSRiver Riddle
111152d29ccSRiver Riddle namespace {
112152d29ccSRiver Riddle struct SCCP : public SCCPBase<SCCP> {
113152d29ccSRiver Riddle void runOnOperation() override;
114152d29ccSRiver Riddle };
115be0a7e9fSMehdi Amini } // namespace
116152d29ccSRiver Riddle
runOnOperation()117152d29ccSRiver Riddle void SCCP::runOnOperation() {
118152d29ccSRiver Riddle Operation *op = getOperation();
119152d29ccSRiver Riddle
1209432fbfeSMogball DataFlowSolver solver;
1219432fbfeSMogball solver.load<DeadCodeAnalysis>();
1229432fbfeSMogball solver.load<SparseConstantPropagation>();
1239432fbfeSMogball if (failed(solver.initializeAndRun(op)))
1249432fbfeSMogball return signalPassFailure();
1259432fbfeSMogball rewrite(solver, op->getContext(), op->getRegions());
126152d29ccSRiver Riddle }
127152d29ccSRiver Riddle
createSCCPPass()128152d29ccSRiver Riddle std::unique_ptr<Pass> mlir::createSCCPPass() {
129152d29ccSRiver Riddle return std::make_unique<SCCP>();
130152d29ccSRiver Riddle }
131