//===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This transformation pass performs a sparse conditional constant propagation
// in MLIR. It identifies values known to be constant, propagates that
// information throughout the IR, and replaces them. This is done with an
// optimistic dataflow analysis that assumes that all values are constant until
// proven otherwise.
//
//===----------------------------------------------------------------------===//
16152d29ccSRiver Riddle 
17152d29ccSRiver Riddle #include "PassDetail.h"
189432fbfeSMogball #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
199432fbfeSMogball #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h"
20152d29ccSRiver Riddle #include "mlir/IR/Builders.h"
21152d29ccSRiver Riddle #include "mlir/IR/Dialect.h"
22eb623ae8SStephen Neuendorffer #include "mlir/Interfaces/SideEffectInterfaces.h"
23152d29ccSRiver Riddle #include "mlir/Pass/Pass.h"
24152d29ccSRiver Riddle #include "mlir/Transforms/FoldUtils.h"
25152d29ccSRiver Riddle #include "mlir/Transforms/Passes.h"
26152d29ccSRiver Riddle 
27152d29ccSRiver Riddle using namespace mlir;
289432fbfeSMogball using namespace mlir::dataflow;
29a90151d6SRiver Riddle 
//===----------------------------------------------------------------------===//
// SCCP Rewrites
//===----------------------------------------------------------------------===//
33a90151d6SRiver Riddle 
34152d29ccSRiver Riddle /// Replace the given value with a constant if the corresponding lattice
35152d29ccSRiver Riddle /// represents a constant. Returns success if the value was replaced, failure
36152d29ccSRiver Riddle /// otherwise.
replaceWithConstant(DataFlowSolver & solver,OpBuilder & builder,OperationFolder & folder,Value value)379432fbfeSMogball static LogicalResult replaceWithConstant(DataFlowSolver &solver,
38d07c90e3SRiver Riddle                                          OpBuilder &builder,
39d07c90e3SRiver Riddle                                          OperationFolder &folder, Value value) {
409432fbfeSMogball   auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value);
41*ab701975SMogball   if (!lattice || lattice->isUninitialized())
42d07c90e3SRiver Riddle     return failure();
439432fbfeSMogball   const ConstantValue &latticeValue = lattice->getValue();
449432fbfeSMogball   if (!latticeValue.getConstantValue())
45d07c90e3SRiver Riddle     return failure();
46152d29ccSRiver Riddle 
47d07c90e3SRiver Riddle   // Attempt to materialize a constant for the given value.
489432fbfeSMogball   Dialect *dialect = latticeValue.getConstantDialect();
499432fbfeSMogball   Value constant = folder.getOrCreateConstant(builder, dialect,
509432fbfeSMogball                                               latticeValue.getConstantValue(),
519432fbfeSMogball                                               value.getType(), value.getLoc());
52d07c90e3SRiver Riddle   if (!constant)
53d07c90e3SRiver Riddle     return failure();
54d07c90e3SRiver Riddle 
55d07c90e3SRiver Riddle   value.replaceAllUsesWith(constant);
56d07c90e3SRiver Riddle   return success();
572eda87dfSRiver Riddle }
582eda87dfSRiver Riddle 
59d07c90e3SRiver Riddle /// Rewrite the given regions using the computing analysis. This replaces the
60d07c90e3SRiver Riddle /// uses of all values that have been computed to be constant, and erases as
61d07c90e3SRiver Riddle /// many newly dead operations.
rewrite(DataFlowSolver & solver,MLIRContext * context,MutableArrayRef<Region> initialRegions)629432fbfeSMogball static void rewrite(DataFlowSolver &solver, MLIRContext *context,
63152d29ccSRiver Riddle                     MutableArrayRef<Region> initialRegions) {
64d07c90e3SRiver Riddle   SmallVector<Block *> worklist;
65152d29ccSRiver Riddle   auto addToWorklist = [&](MutableArrayRef<Region> regions) {
66152d29ccSRiver Riddle     for (Region &region : regions)
67d07c90e3SRiver Riddle       for (Block &block : llvm::reverse(region))
68152d29ccSRiver Riddle         worklist.push_back(&block);
69152d29ccSRiver Riddle   };
70152d29ccSRiver Riddle 
71152d29ccSRiver Riddle   // An operation folder used to create and unique constants.
72152d29ccSRiver Riddle   OperationFolder folder(context);
73152d29ccSRiver Riddle   OpBuilder builder(context);
74152d29ccSRiver Riddle 
75152d29ccSRiver Riddle   addToWorklist(initialRegions);
76152d29ccSRiver Riddle   while (!worklist.empty()) {
77152d29ccSRiver Riddle     Block *block = worklist.pop_back_val();
78152d29ccSRiver Riddle 
79152d29ccSRiver Riddle     for (Operation &op : llvm::make_early_inc_range(*block)) {
80152d29ccSRiver Riddle       builder.setInsertionPoint(&op);
81152d29ccSRiver Riddle 
82152d29ccSRiver Riddle       // Replace any result with constants.
83152d29ccSRiver Riddle       bool replacedAll = op.getNumResults() != 0;
84152d29ccSRiver Riddle       for (Value res : op.getResults())
85d07c90e3SRiver Riddle         replacedAll &=
869432fbfeSMogball             succeeded(replaceWithConstant(solver, builder, folder, res));
87152d29ccSRiver Riddle 
88152d29ccSRiver Riddle       // If all of the results of the operation were replaced, try to erase
89152d29ccSRiver Riddle       // the operation completely.
90152d29ccSRiver Riddle       if (replacedAll && wouldOpBeTriviallyDead(&op)) {
91152d29ccSRiver Riddle         assert(op.use_empty() && "expected all uses to be replaced");
92152d29ccSRiver Riddle         op.erase();
93152d29ccSRiver Riddle         continue;
94152d29ccSRiver Riddle       }
95152d29ccSRiver Riddle 
96152d29ccSRiver Riddle       // Add any the regions of this operation to the worklist.
97152d29ccSRiver Riddle       addToWorklist(op.getRegions());
98152d29ccSRiver Riddle     }
99d07c90e3SRiver Riddle 
100d07c90e3SRiver Riddle     // Replace any block arguments with constants.
101d07c90e3SRiver Riddle     builder.setInsertionPointToStart(block);
102d07c90e3SRiver Riddle     for (BlockArgument arg : block->getArguments())
1039432fbfeSMogball       (void)replaceWithConstant(solver, builder, folder, arg);
104152d29ccSRiver Riddle   }
105152d29ccSRiver Riddle }
106152d29ccSRiver Riddle 
//===----------------------------------------------------------------------===//
// SCCP Pass
//===----------------------------------------------------------------------===//
110152d29ccSRiver Riddle 
111152d29ccSRiver Riddle namespace {
112152d29ccSRiver Riddle struct SCCP : public SCCPBase<SCCP> {
113152d29ccSRiver Riddle   void runOnOperation() override;
114152d29ccSRiver Riddle };
115be0a7e9fSMehdi Amini } // namespace
116152d29ccSRiver Riddle 
runOnOperation()117152d29ccSRiver Riddle void SCCP::runOnOperation() {
118152d29ccSRiver Riddle   Operation *op = getOperation();
119152d29ccSRiver Riddle 
1209432fbfeSMogball   DataFlowSolver solver;
1219432fbfeSMogball   solver.load<DeadCodeAnalysis>();
1229432fbfeSMogball   solver.load<SparseConstantPropagation>();
1239432fbfeSMogball   if (failed(solver.initializeAndRun(op)))
1249432fbfeSMogball     return signalPassFailure();
1259432fbfeSMogball   rewrite(solver, op->getContext(), op->getRegions());
126152d29ccSRiver Riddle }
127152d29ccSRiver Riddle 
createSCCPPass()128152d29ccSRiver Riddle std::unique_ptr<Pass> mlir::createSCCPPass() {
129152d29ccSRiver Riddle   return std::make_unique<SCCP>();
130152d29ccSRiver Riddle }
131