1 //===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This transformation pass performs a sparse conditional constant propagation 10 // in MLIR. It identifies values known to be constant, propagates that 11 // information throughout the IR, and replaces them. This is done with an 12 // optimistic dataflow analysis that assumes that all values are constant until 13 // proven otherwise. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "PassDetail.h" 18 #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" 19 #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/Dialect.h" 22 #include "mlir/Interfaces/SideEffectInterfaces.h" 23 #include "mlir/Pass/Pass.h" 24 #include "mlir/Transforms/FoldUtils.h" 25 #include "mlir/Transforms/Passes.h" 26 27 using namespace mlir; 28 using namespace mlir::dataflow; 29 30 //===----------------------------------------------------------------------===// 31 // SCCP Rewrites 32 //===----------------------------------------------------------------------===// 33 34 /// Replace the given value with a constant if the corresponding lattice 35 /// represents a constant. Returns success if the value was replaced, failure 36 /// otherwise. 37 static LogicalResult replaceWithConstant(DataFlowSolver &solver, 38 OpBuilder &builder, 39 OperationFolder &folder, Value value) { 40 auto *lattice = solver.lookupState<Lattice<ConstantValue>>(value); 41 if (!lattice || lattice->isUninitialized()) 42 return failure(); 43 const ConstantValue &latticeValue = lattice->getValue(); 44 if (!latticeValue.getConstantValue()) 45 return failure(); 46 47 // Attempt to materialize a constant for the given value. 48 Dialect *dialect = latticeValue.getConstantDialect(); 49 Value constant = folder.getOrCreateConstant(builder, dialect, 50 latticeValue.getConstantValue(), 51 value.getType(), value.getLoc()); 52 if (!constant) 53 return failure(); 54 55 value.replaceAllUsesWith(constant); 56 return success(); 57 } 58 59 /// Rewrite the given regions using the computing analysis. This replaces the 60 /// uses of all values that have been computed to be constant, and erases as 61 /// many newly dead operations. 62 static void rewrite(DataFlowSolver &solver, MLIRContext *context, 63 MutableArrayRef<Region> initialRegions) { 64 SmallVector<Block *> worklist; 65 auto addToWorklist = [&](MutableArrayRef<Region> regions) { 66 for (Region ®ion : regions) 67 for (Block &block : llvm::reverse(region)) 68 worklist.push_back(&block); 69 }; 70 71 // An operation folder used to create and unique constants. 72 OperationFolder folder(context); 73 OpBuilder builder(context); 74 75 addToWorklist(initialRegions); 76 while (!worklist.empty()) { 77 Block *block = worklist.pop_back_val(); 78 79 for (Operation &op : llvm::make_early_inc_range(*block)) { 80 builder.setInsertionPoint(&op); 81 82 // Replace any result with constants. 83 bool replacedAll = op.getNumResults() != 0; 84 for (Value res : op.getResults()) 85 replacedAll &= 86 succeeded(replaceWithConstant(solver, builder, folder, res)); 87 88 // If all of the results of the operation were replaced, try to erase 89 // the operation completely. 90 if (replacedAll && wouldOpBeTriviallyDead(&op)) { 91 assert(op.use_empty() && "expected all uses to be replaced"); 92 op.erase(); 93 continue; 94 } 95 96 // Add any the regions of this operation to the worklist. 97 addToWorklist(op.getRegions()); 98 } 99 100 // Replace any block arguments with constants. 101 builder.setInsertionPointToStart(block); 102 for (BlockArgument arg : block->getArguments()) 103 (void)replaceWithConstant(solver, builder, folder, arg); 104 } 105 } 106 107 //===----------------------------------------------------------------------===// 108 // SCCP Pass 109 //===----------------------------------------------------------------------===// 110 111 namespace { 112 struct SCCP : public SCCPBase<SCCP> { 113 void runOnOperation() override; 114 }; 115 } // namespace 116 117 void SCCP::runOnOperation() { 118 Operation *op = getOperation(); 119 120 DataFlowSolver solver; 121 solver.load<DeadCodeAnalysis>(); 122 solver.load<SparseConstantPropagation>(); 123 if (failed(solver.initializeAndRun(op))) 124 return signalPassFailure(); 125 rewrite(solver, op->getContext(), op->getRegions()); 126 } 127 128 std::unique_ptr<Pass> mlir::createSCCPPass() { 129 return std::make_unique<SCCP>(); 130 } 131