1 //===- TestSymbolUses.cpp - Pass to test symbol uselists ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "TestDialect.h"
10 #include "mlir/IR/BuiltinOps.h"
11 #include "mlir/Pass/Pass.h"
12 
13 using namespace mlir;
14 
15 namespace {
16 /// This is a symbol test pass that tests the symbol uselist functionality
17 /// provided by the symbol table along with erasing from the symbol table.
18 struct SymbolUsesPass
19     : public PassWrapper<SymbolUsesPass, OperationPass<ModuleOp>> {
20   StringRef getArgument() const final { return "test-symbol-uses"; }
21   StringRef getDescription() const final {
22     return "Test detection of symbol uses";
23   }
24   WalkResult operateOnSymbol(Operation *symbol, ModuleOp module,
25                              SmallVectorImpl<FuncOp> &deadFunctions) {
26     // Test computing uses on a non symboltable op.
27     Optional<SymbolTable::UseRange> symbolUses =
28         SymbolTable::getSymbolUses(symbol);
29 
30     // Test the conservative failure case.
31     if (!symbolUses) {
32       symbol->emitRemark()
33           << "symbol contains an unknown nested operation that "
34              "'may' define a new symbol table";
35       return WalkResult::interrupt();
36     }
37     if (unsigned numUses = llvm::size(*symbolUses))
38       symbol->emitRemark() << "symbol contains " << numUses
39                            << " nested references";
40 
41     // Test the functionality of symbolKnownUseEmpty.
42     if (SymbolTable::symbolKnownUseEmpty(symbol, &module.getBodyRegion())) {
43       FuncOp funcSymbol = dyn_cast<FuncOp>(symbol);
44       if (funcSymbol && funcSymbol.isExternal())
45         deadFunctions.push_back(funcSymbol);
46 
47       symbol->emitRemark() << "symbol has no uses";
48       return WalkResult::advance();
49     }
50 
51     // Test the functionality of getSymbolUses.
52     symbolUses = SymbolTable::getSymbolUses(symbol, &module.getBodyRegion());
53     assert(symbolUses.hasValue() && "expected no unknown operations");
54     for (SymbolTable::SymbolUse symbolUse : *symbolUses) {
55       // Check that we can resolve back to our symbol.
56       if (SymbolTable::lookupNearestSymbolFrom(
57               symbolUse.getUser()->getParentOp(), symbolUse.getSymbolRef())) {
58         symbolUse.getUser()->emitRemark()
59             << "found use of symbol : " << symbolUse.getSymbolRef() << " : "
60             << symbol->getAttr(SymbolTable::getSymbolAttrName());
61       }
62     }
63     symbol->emitRemark() << "symbol has " << llvm::size(*symbolUses) << " uses";
64     return WalkResult::advance();
65   }
66 
67   void runOnOperation() override {
68     auto module = getOperation();
69 
70     // Walk nested symbols.
71     SmallVector<FuncOp, 4> deadFunctions;
72     module.getBodyRegion().walk([&](Operation *nestedOp) {
73       if (isa<SymbolOpInterface>(nestedOp))
74         return operateOnSymbol(nestedOp, module, deadFunctions);
75       return WalkResult::advance();
76     });
77 
78     SymbolTable table(module);
79     for (Operation *op : deadFunctions) {
80       // In order to test the SymbolTable::erase method, also erase completely
81       // useless functions.
82       auto name = SymbolTable::getSymbolName(op);
83       assert(table.lookup(name) && "expected no unknown operations");
84       table.erase(op);
85       assert(!table.lookup(name) &&
86              "expected erased operation to be unknown now");
87       module.emitRemark() << name.getValue() << " function successfully erased";
88     }
89   }
90 };
91 
92 /// This is a symbol test pass that tests the symbol use replacement
93 /// functionality provided by the symbol table.
94 struct SymbolReplacementPass
95     : public PassWrapper<SymbolReplacementPass, OperationPass<ModuleOp>> {
96   StringRef getArgument() const final { return "test-symbol-rauw"; }
97   StringRef getDescription() const final {
98     return "Test replacement of symbol uses";
99   }
100   void runOnOperation() override {
101     ModuleOp module = getOperation();
102 
103     // Don't try to replace if we can't collect symbol uses.
104     if (!SymbolTable::getSymbolUses(&module.getBodyRegion()))
105       return;
106 
107     SymbolTableCollection symbolTable;
108     SymbolUserMap symbolUsers(symbolTable, module);
109     module.getBodyRegion().walk([&](Operation *nestedOp) {
110       StringAttr newName = nestedOp->getAttrOfType<StringAttr>("sym.new_name");
111       if (!newName)
112         return;
113       symbolUsers.replaceAllUsesWith(nestedOp, newName);
114       SymbolTable::setSymbolName(nestedOp, newName);
115     });
116   }
117 };
118 } // namespace
119 
120 namespace mlir {
121 void registerSymbolTestPasses() {
122   PassRegistration<SymbolUsesPass>();
123 
124   PassRegistration<SymbolReplacementPass>();
125 }
126 } // namespace mlir
127