1 //===- SymbolDCE.cpp - Pass to delete dead symbols ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an algorithm for eliminating symbol operations that are
10 // known to be dead.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/Passes.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 struct SymbolDCE : public OperationPass<SymbolDCE> {
21   void runOnOperation() override;
22 
23   /// Compute the liveness of the symbols within the given symbol table.
24   /// `symbolTableIsHidden` is true if this symbol table is known to be
25   /// unaccessible from operations in its parent regions.
26   LogicalResult computeLiveness(Operation *symbolTableOp,
27                                 bool symbolTableIsHidden,
28                                 DenseSet<Operation *> &liveSymbols);
29 };
30 } // end anonymous namespace
31 
32 void SymbolDCE::runOnOperation() {
33   Operation *symbolTableOp = getOperation();
34 
35   // SymbolDCE should only be run on operations that define a symbol table.
36   if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
37     symbolTableOp->emitOpError()
38         << " was scheduled to run under SymbolDCE, but does not define a "
39            "symbol table";
40     return signalPassFailure();
41   }
42 
43   // A flag that signals if the top level symbol table is hidden, i.e. not
44   // accessible from parent scopes.
45   bool symbolTableIsHidden = true;
46   if (symbolTableOp->getParentOp() && SymbolTable::isSymbol(symbolTableOp)) {
47     symbolTableIsHidden = SymbolTable::getSymbolVisibility(symbolTableOp) ==
48                           SymbolTable::Visibility::Private;
49   }
50 
51   // Compute the set of live symbols within the symbol table.
52   DenseSet<Operation *> liveSymbols;
53   if (failed(computeLiveness(symbolTableOp, symbolTableIsHidden, liveSymbols)))
54     return signalPassFailure();
55 
56   // After computing the liveness, delete all of the symbols that were found to
57   // be dead.
58   symbolTableOp->walk([&](Operation *nestedSymbolTable) {
59     if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>())
60       return;
61     for (auto &block : nestedSymbolTable->getRegion(0)) {
62       for (Operation &op :
63            llvm::make_early_inc_range(block.without_terminator())) {
64         if (SymbolTable::isSymbol(&op) && !liveSymbols.count(&op))
65           op.erase();
66       }
67     }
68   });
69 }
70 
71 /// Compute the liveness of the symbols within the given symbol table.
72 /// `symbolTableIsHidden` is true if this symbol table is known to be
73 /// unaccessible from operations in its parent regions.
74 LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
75                                          bool symbolTableIsHidden,
76                                          DenseSet<Operation *> &liveSymbols) {
77   // A worklist of live operations to propagate uses from.
78   SmallVector<Operation *, 16> worklist;
79 
80   // Walk the symbols within the current symbol table, marking the symbols that
81   // are known to be live.
82   for (auto &block : symbolTableOp->getRegion(0)) {
83     for (Operation &op : block.without_terminator()) {
84       // Always add non symbol operations to the worklist.
85       if (!SymbolTable::isSymbol(&op)) {
86         worklist.push_back(&op);
87         continue;
88       }
89 
90       // Check the visibility to see if this symbol may be referenced
91       // externally.
92       SymbolTable::Visibility visibility =
93           SymbolTable::getSymbolVisibility(&op);
94 
95       // Private symbols are always initially considered dead.
96       if (visibility == mlir::SymbolTable::Visibility::Private)
97         continue;
98       // We only include nested visibility here if the symbol table isn't
99       // hidden.
100       if (symbolTableIsHidden && visibility == SymbolTable::Visibility::Nested)
101         continue;
102 
103       // TODO(riverriddle) Add hooks here to allow symbols to provide additional
104       // information, e.g. linkage can be used to drop some symbols that may
105       // otherwise be considered "live".
106       if (liveSymbols.insert(&op).second)
107         worklist.push_back(&op);
108     }
109   }
110 
111   // Process the set of symbols that were known to be live, adding new symbols
112   // that are referenced within.
113   while (!worklist.empty()) {
114     Operation *op = worklist.pop_back_val();
115 
116     // If this is a symbol table, recursively compute its liveness.
117     if (op->hasTrait<OpTrait::SymbolTable>()) {
118       // The internal symbol table is hidden if the parent is, if its not a
119       // symbol, or if it is a private symbol.
120       bool symbolIsHidden = symbolTableIsHidden || !SymbolTable::isSymbol(op) ||
121                             SymbolTable::getSymbolVisibility(op) ==
122                                 SymbolTable::Visibility::Private;
123       if (failed(computeLiveness(op, symbolIsHidden, liveSymbols)))
124         return failure();
125     }
126 
127     // Collect the uses held by this operation.
128     Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(op);
129     if (!uses) {
130       return op->emitError()
131              << "operation contains potentially unknown symbol table, "
132                 "meaning that we can't reliable compute symbol uses";
133     }
134 
135     SmallVector<Operation *, 4> resolvedSymbols;
136     for (const SymbolTable::SymbolUse &use : *uses) {
137       // Lookup the symbols referenced by this use.
138       resolvedSymbols.clear();
139       if (failed(SymbolTable::lookupSymbolIn(
140               op->getParentOp(), use.getSymbolRef(), resolvedSymbols))) {
141         return use.getUser()->emitError()
142                << "unable to resolve reference to symbol "
143                << use.getSymbolRef();
144       }
145 
146       // Mark each of the resolved symbols as live.
147       for (Operation *resolvedSymbol : resolvedSymbols)
148         if (liveSymbols.insert(resolvedSymbol).second)
149           worklist.push_back(resolvedSymbol);
150     }
151   }
152 
153   return success();
154 }
155 
156 std::unique_ptr<Pass> mlir::createSymbolDCEPass() {
157   return std::make_unique<SymbolDCE>();
158 }
159 
160 static PassRegistration<SymbolDCE> pass("symbol-dce", "Eliminate dead symbols");
161