1 //===- SymbolDCE.cpp - Pass to delete dead symbols ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements an algorithm for eliminating symbol operations that are
10 // known to be dead.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Transforms/Passes.h"
16 
17 using namespace mlir;
18 
19 namespace {
20 struct SymbolDCE : public SymbolDCEBase<SymbolDCE> {
21   void runOnOperation() override;
22 
23   /// Compute the liveness of the symbols within the given symbol table.
24   /// `symbolTableIsHidden` is true if this symbol table is known to be
25   /// unaccessible from operations in its parent regions.
26   LogicalResult computeLiveness(Operation *symbolTableOp,
27                                 SymbolTableCollection &symbolTable,
28                                 bool symbolTableIsHidden,
29                                 DenseSet<Operation *> &liveSymbols);
30 };
31 } // namespace
32 
33 void SymbolDCE::runOnOperation() {
34   Operation *symbolTableOp = getOperation();
35 
36   // SymbolDCE should only be run on operations that define a symbol table.
37   if (!symbolTableOp->hasTrait<OpTrait::SymbolTable>()) {
38     symbolTableOp->emitOpError()
39         << " was scheduled to run under SymbolDCE, but does not define a "
40            "symbol table";
41     return signalPassFailure();
42   }
43 
44   // A flag that signals if the top level symbol table is hidden, i.e. not
45   // accessible from parent scopes.
46   bool symbolTableIsHidden = true;
47   SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(symbolTableOp);
48   if (symbolTableOp->getParentOp() && symbol)
49     symbolTableIsHidden = symbol.isPrivate();
50 
51   // Compute the set of live symbols within the symbol table.
52   DenseSet<Operation *> liveSymbols;
53   SymbolTableCollection symbolTable;
54   if (failed(computeLiveness(symbolTableOp, symbolTable, symbolTableIsHidden,
55                              liveSymbols)))
56     return signalPassFailure();
57 
58   // After computing the liveness, delete all of the symbols that were found to
59   // be dead.
60   symbolTableOp->walk([&](Operation *nestedSymbolTable) {
61     if (!nestedSymbolTable->hasTrait<OpTrait::SymbolTable>())
62       return;
63     for (auto &block : nestedSymbolTable->getRegion(0)) {
64       for (Operation &op : llvm::make_early_inc_range(block)) {
65         if (isa<SymbolOpInterface>(&op) && !liveSymbols.count(&op)) {
66           op.erase();
67           ++numDCE;
68         }
69       }
70     }
71   });
72 }
73 
74 /// Compute the liveness of the symbols within the given symbol table.
75 /// `symbolTableIsHidden` is true if this symbol table is known to be
76 /// unaccessible from operations in its parent regions.
77 LogicalResult SymbolDCE::computeLiveness(Operation *symbolTableOp,
78                                          SymbolTableCollection &symbolTable,
79                                          bool symbolTableIsHidden,
80                                          DenseSet<Operation *> &liveSymbols) {
81   // A worklist of live operations to propagate uses from.
82   SmallVector<Operation *, 16> worklist;
83 
84   // Walk the symbols within the current symbol table, marking the symbols that
85   // are known to be live.
86   for (auto &block : symbolTableOp->getRegion(0)) {
87     // Add all non-symbols or symbols that can't be discarded.
88     for (Operation &op : block) {
89       SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(&op);
90       if (!symbol) {
91         worklist.push_back(&op);
92         continue;
93       }
94       bool isDiscardable = (symbolTableIsHidden || symbol.isPrivate()) &&
95                            symbol.canDiscardOnUseEmpty();
96       if (!isDiscardable && liveSymbols.insert(&op).second)
97         worklist.push_back(&op);
98     }
99   }
100 
101   // Process the set of symbols that were known to be live, adding new symbols
102   // that are referenced within.
103   while (!worklist.empty()) {
104     Operation *op = worklist.pop_back_val();
105 
106     // If this is a symbol table, recursively compute its liveness.
107     if (op->hasTrait<OpTrait::SymbolTable>()) {
108       // The internal symbol table is hidden if the parent is, if its not a
109       // symbol, or if it is a private symbol.
110       SymbolOpInterface symbol = dyn_cast<SymbolOpInterface>(op);
111       bool symIsHidden = symbolTableIsHidden || !symbol || symbol.isPrivate();
112       if (failed(computeLiveness(op, symbolTable, symIsHidden, liveSymbols)))
113         return failure();
114     }
115 
116     // Collect the uses held by this operation.
117     Optional<SymbolTable::UseRange> uses = SymbolTable::getSymbolUses(op);
118     if (!uses) {
119       return op->emitError()
120              << "operation contains potentially unknown symbol table, "
121                 "meaning that we can't reliable compute symbol uses";
122     }
123 
124     SmallVector<Operation *, 4> resolvedSymbols;
125     for (const SymbolTable::SymbolUse &use : *uses) {
126       // Lookup the symbols referenced by this use.
127       resolvedSymbols.clear();
128       if (failed(symbolTable.lookupSymbolIn(
129               op->getParentOp(), use.getSymbolRef(), resolvedSymbols)))
130         // Ignore references to unknown symbols.
131         continue;
132 
133       // Mark each of the resolved symbols as live.
134       for (Operation *resolvedSymbol : resolvedSymbols)
135         if (liveSymbols.insert(resolvedSymbol).second)
136           worklist.push_back(resolvedSymbol);
137     }
138   }
139 
140   return success();
141 }
142 
143 std::unique_ptr<Pass> mlir::createSymbolDCEPass() {
144   return std::make_unique<SymbolDCE>();
145 }
146