1 //===- TestSymbolUses.cpp - Pass to test symbol uselists ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/IR/BuiltinOps.h" 11 #include "mlir/Pass/Pass.h" 12 13 using namespace mlir; 14 15 namespace { 16 /// This is a symbol test pass that tests the symbol uselist functionality 17 /// provided by the symbol table along with erasing from the symbol table. 18 struct SymbolUsesPass 19 : public PassWrapper<SymbolUsesPass, OperationPass<ModuleOp>> { 20 WalkResult operateOnSymbol(Operation *symbol, ModuleOp module, 21 SmallVectorImpl<FuncOp> &deadFunctions) { 22 // Test computing uses on a non symboltable op. 23 Optional<SymbolTable::UseRange> symbolUses = 24 SymbolTable::getSymbolUses(symbol); 25 26 // Test the conservative failure case. 27 if (!symbolUses) { 28 symbol->emitRemark() 29 << "symbol contains an unknown nested operation that " 30 "'may' define a new symbol table"; 31 return WalkResult::interrupt(); 32 } 33 if (unsigned numUses = llvm::size(*symbolUses)) 34 symbol->emitRemark() << "symbol contains " << numUses 35 << " nested references"; 36 37 // Test the functionality of symbolKnownUseEmpty. 38 if (SymbolTable::symbolKnownUseEmpty(symbol, &module.getBodyRegion())) { 39 FuncOp funcSymbol = dyn_cast<FuncOp>(symbol); 40 if (funcSymbol && funcSymbol.isExternal()) 41 deadFunctions.push_back(funcSymbol); 42 43 symbol->emitRemark() << "symbol has no uses"; 44 return WalkResult::advance(); 45 } 46 47 // Test the functionality of getSymbolUses. 48 symbolUses = SymbolTable::getSymbolUses(symbol, &module.getBodyRegion()); 49 assert(symbolUses.hasValue() && "expected no unknown operations"); 50 for (SymbolTable::SymbolUse symbolUse : *symbolUses) { 51 // Check that we can resolve back to our symbol. 52 if (SymbolTable::lookupNearestSymbolFrom( 53 symbolUse.getUser()->getParentOp(), symbolUse.getSymbolRef())) { 54 symbolUse.getUser()->emitRemark() 55 << "found use of symbol : " << symbolUse.getSymbolRef() << " : " 56 << symbol->getAttr(SymbolTable::getSymbolAttrName()); 57 } 58 } 59 symbol->emitRemark() << "symbol has " << llvm::size(*symbolUses) << " uses"; 60 return WalkResult::advance(); 61 } 62 63 void runOnOperation() override { 64 auto module = getOperation(); 65 66 // Walk nested symbols. 67 SmallVector<FuncOp, 4> deadFunctions; 68 module.getBodyRegion().walk([&](Operation *nestedOp) { 69 if (isa<SymbolOpInterface>(nestedOp)) 70 return operateOnSymbol(nestedOp, module, deadFunctions); 71 return WalkResult::advance(); 72 }); 73 74 SymbolTable table(module); 75 for (Operation *op : deadFunctions) { 76 // In order to test the SymbolTable::erase method, also erase completely 77 // useless functions. 78 auto name = SymbolTable::getSymbolName(op); 79 assert(table.lookup(name) && "expected no unknown operations"); 80 table.erase(op); 81 assert(!table.lookup(name) && 82 "expected erased operation to be unknown now"); 83 module.emitRemark() << name << " function successfully erased"; 84 } 85 } 86 }; 87 88 /// This is a symbol test pass that tests the symbol use replacement 89 /// functionality provided by the symbol table. 90 struct SymbolReplacementPass 91 : public PassWrapper<SymbolReplacementPass, OperationPass<ModuleOp>> { 92 void runOnOperation() override { 93 auto module = getOperation(); 94 95 // Walk nested functions and modules. 96 module.getBodyRegion().walk([&](Operation *nestedOp) { 97 StringAttr newName = nestedOp->getAttrOfType<StringAttr>("sym.new_name"); 98 if (!newName) 99 return; 100 if (succeeded(SymbolTable::replaceAllSymbolUses( 101 nestedOp, newName.getValue(), &module.getBodyRegion()))) 102 SymbolTable::setSymbolName(nestedOp, newName.getValue()); 103 }); 104 } 105 }; 106 } // end anonymous namespace 107 108 namespace mlir { 109 void registerSymbolTestPasses() { 110 PassRegistration<SymbolUsesPass>("test-symbol-uses", 111 "Test detection of symbol uses"); 112 113 PassRegistration<SymbolReplacementPass>("test-symbol-rauw", 114 "Test replacement of symbol uses"); 115 } 116 } // namespace mlir 117