1 //===- SymbolDCE.cpp - Pass to delete dead symbols ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an algorithm for eliminating symbol operations that are
10 // known to be dead.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/IR/SymbolTable.h"
16 #include "mlir/Transforms/Passes.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 struct SymbolDCE : public SymbolDCEBase<SymbolDCE> {
22   void runOnOperation() override;
23 
24   /// Compute the liveness of the symbols within the given symbol table.
25   /// `symbolTableIsHidden` is true if this symbol table is known to be
26   /// unaccessible from operations in its parent regions.
27   LogicalResult computeLiveness(Operation *symbolTableOp,
28                                 SymbolTableCollection &symbolTable,
29                                 bool symbolTableIsHidden,
30                                 DenseSet<Operation *> &liveSymbols);
31 };
32 } // namespace
33 
runOnOperation()34 void SymbolDCE::runOnOperation() {
35   Operation *symbolTableOp = getOperation();
36 
37   // SymbolDCE should only be run on operations that define a symbol table.
38   if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
39     symbolTableOp->emitOpError()
40         << " was scheduled to run under SymbolDCE, but does not define a "
41            "symbol table";
42     return signalPassFailure();
43   }
44 
45   // A flag that signals if the top level symbol table is hidden, i.e. not
46   // accessible from parent scopes.
47   bool symbolTableIsHidden = true;
48   SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(symbolTableOp);
49   if (symbolTableOp->getParentOp() && symbol)
50     symbolTableIsHidden = symbol.isPrivate();
51 
52   // Compute the set of live symbols within the symbol table.
53   DenseSet<Operation *> liveSymbols;
54   SymbolTableCollection symbolTable;
55   if (failed(computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden,
56                              liveSymbols)))
57     return signalPassFailure();
58 
59   // After computing the liveness, delete all of the symbols that were found to
60   // be dead.
61   symbolTableOp->walk([&](Operation *nestedSymbolTable) {
62     if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>())
63       return;
64     for (auto &block : nestedSymbolTable->getRegion(0)) {
65       for (Operation &op : llvm::make_early_inc_range(block)) {
66         if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op)) {
67           op.erase();
68           ++numDCE;
69         }
70       }
71     }
72   });
73 }
74 
75 /// Compute the liveness of the symbols within the given symbol table.
76 /// `symbolTableIsHidden` is true if this symbol table is known to be
77 /// unaccessible from operations in its parent regions.
computeLiveness(Operation * symbolTableOp,SymbolTableCollection & symbolTable,bool symbolTableIsHidden,DenseSet<Operation * > & liveSymbols)78 LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
79                                          SymbolTableCollection &symbolTable,
80                                          bool symbolTableIsHidden,
81                                          DenseSet<Operation *> &liveSymbols) {
82   // A worklist of live operations to propagate uses from.
83   SmallVector<Operation *, 16> worklist;
84 
85   // Walk the symbols within the current symbol table, marking the symbols that
86   // are known to be live.
87   for (auto &block : symbolTableOp->getRegion(0)) {
88     // Add all non-symbols or symbols that can't be discarded.
89     for (Operation &op : block) {
90       SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
91       if (!symbol) {
92         worklist.push_back(&op);
93         continue;
94       }
95       bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
96                            symbol.canDiscardOnUseEmpty();
97       if (!isDiscardable && liveSymbols.insert(&op).second)
98         worklist.push_back(&op);
99     }
100   }
101 
102   // Process the set of symbols that were known to be live, adding new symbols
103   // that are referenced within.
104   while (!worklist.empty()) {
105     Operation *op = worklist.pop_back_val();
106 
107     // If this is a symbol table, recursively compute its liveness.
108     if (op->hasTrait<OpTrait::SymbolTable>()) {
109       // The internal symbol table is hidden if the parent is, if its not a
110       // symbol, or if it is a private symbol.
111       SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
112       bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
113       if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
114         return failure();
115     }
116 
117     // Collect the uses held by this operation.
118     Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(op);
119     if (!uses) {
120       return op->emitError()
121              << "operation contains potentially unknown symbol table, "
122                 "meaning that we can't reliable compute symbol uses";
123     }
124 
125     SmallVector<Operation *, 4> resolvedSymbols;
126     for (const SymbolTable::SymbolUse &use : *uses) {
127       // Lookup the symbols referenced by this use.
128       resolvedSymbols.clear();
129       if (failed(symbolTable.lookupSymbolIn(
130               op->getParentOp(), use.getSymbolRef(), resolvedSymbols)))
131         // Ignore references to unknown symbols.
132         continue;
133 
134       // Mark each of the resolved symbols as live.
135       for (Operation *resolvedSymbol : resolvedSymbols)
136         if (liveSymbols.insert(resolvedSymbol).second)
137           worklist.push_back(resolvedSymbol);
138     }
139   }
140 
141   return success();
142 }
143 
createSymbolDCEPass()144 std::unique_ptr<Pass> mlir::createSymbolDCEPass() {
145   return std::make_unique<SymbolDCE>();
146 }
147