1 //===- TestSymbolUses.cpp - Pass to test symbol uselists ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/IR/BuiltinOps.h" 11 #include "mlir/Pass/Pass.h" 12 13 using namespace mlir; 14 15 namespace { 16 /// This is a symbol test pass that tests the symbol uselist functionality 17 /// provided by the symbol table along with erasing from the symbol table. 18 struct SymbolUsesPass 19 : public PassWrapper<SymbolUsesPass, OperationPass<ModuleOp>> { 20 StringRef getArgument() const final { return "test-symbol-uses"; } 21 StringRef getDescription() const final { 22 return "Test detection of symbol uses"; 23 } 24 WalkResult operateOnSymbol(Operation *symbol, ModuleOp module, 25 SmallVectorImpl<FuncOp> &deadFunctions) { 26 // Test computing uses on a non symboltable op. 27 Optional<SymbolTable::UseRange> symbolUses = 28 SymbolTable::getSymbolUses(symbol); 29 30 // Test the conservative failure case. 31 if (!symbolUses) { 32 symbol->emitRemark() 33 << "symbol contains an unknown nested operation that " 34 "'may' define a new symbol table"; 35 return WalkResult::interrupt(); 36 } 37 if (unsigned numUses = llvm::size(*symbolUses)) 38 symbol->emitRemark() << "symbol contains " << numUses 39 << " nested references"; 40 41 // Test the functionality of symbolKnownUseEmpty. 42 if (SymbolTable::symbolKnownUseEmpty(symbol, &module.getBodyRegion())) { 43 FuncOp funcSymbol = dyn_cast<FuncOp>(symbol); 44 if (funcSymbol && funcSymbol.isExternal()) 45 deadFunctions.push_back(funcSymbol); 46 47 symbol->emitRemark() << "symbol has no uses"; 48 return WalkResult::advance(); 49 } 50 51 // Test the functionality of getSymbolUses. 52 symbolUses = SymbolTable::getSymbolUses(symbol, &module.getBodyRegion()); 53 assert(symbolUses.hasValue() && "expected no unknown operations"); 54 for (SymbolTable::SymbolUse symbolUse : *symbolUses) { 55 // Check that we can resolve back to our symbol. 56 if (SymbolTable::lookupNearestSymbolFrom( 57 symbolUse.getUser()->getParentOp(), symbolUse.getSymbolRef())) { 58 symbolUse.getUser()->emitRemark() 59 << "found use of symbol : " << symbolUse.getSymbolRef() << " : " 60 << symbol->getAttr(SymbolTable::getSymbolAttrName()); 61 } 62 } 63 symbol->emitRemark() << "symbol has " << llvm::size(*symbolUses) << " uses"; 64 return WalkResult::advance(); 65 } 66 67 void runOnOperation() override { 68 auto module = getOperation(); 69 70 // Walk nested symbols. 71 SmallVector<FuncOp, 4> deadFunctions; 72 module.getBodyRegion().walk([&](Operation *nestedOp) { 73 if (isa<SymbolOpInterface>(nestedOp)) 74 return operateOnSymbol(nestedOp, module, deadFunctions); 75 return WalkResult::advance(); 76 }); 77 78 SymbolTable table(module); 79 for (Operation *op : deadFunctions) { 80 // In order to test the SymbolTable::erase method, also erase completely 81 // useless functions. 82 auto name = SymbolTable::getSymbolName(op); 83 assert(table.lookup(name) && "expected no unknown operations"); 84 table.erase(op); 85 assert(!table.lookup(name) && 86 "expected erased operation to be unknown now"); 87 module.emitRemark() << name.getValue() << " function successfully erased"; 88 } 89 } 90 }; 91 92 /// This is a symbol test pass that tests the symbol use replacement 93 /// functionality provided by the symbol table. 94 struct SymbolReplacementPass 95 : public PassWrapper<SymbolReplacementPass, OperationPass<ModuleOp>> { 96 StringRef getArgument() const final { return "test-symbol-rauw"; } 97 StringRef getDescription() const final { 98 return "Test replacement of symbol uses"; 99 } 100 void runOnOperation() override { 101 ModuleOp module = getOperation(); 102 103 // Don't try to replace if we can't collect symbol uses. 104 if (!SymbolTable::getSymbolUses(&module.getBodyRegion())) 105 return; 106 107 SymbolTableCollection symbolTable; 108 SymbolUserMap symbolUsers(symbolTable, module); 109 module.getBodyRegion().walk([&](Operation *nestedOp) { 110 StringAttr newName = nestedOp->getAttrOfType<StringAttr>("sym.new_name"); 111 if (!newName) 112 return; 113 symbolUsers.replaceAllUsesWith(nestedOp, newName); 114 SymbolTable::setSymbolName(nestedOp, newName); 115 }); 116 } 117 }; 118 } // namespace 119 120 namespace mlir { 121 void registerSymbolTestPasses() { 122 PassRegistration<SymbolUsesPass>(); 123 124 PassRegistration<SymbolReplacementPass>(); 125 } 126 } // namespace mlir 127