1 //===- TestSymbolUses.cpp - Pass to test symbol uselists ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "TestDialect.h" 10 #include "mlir/IR/BuiltinOps.h" 11 #include "mlir/Pass/Pass.h" 12 13 using namespace mlir; 14 15 namespace { 16 /// This is a symbol test pass that tests the symbol uselist functionality 17 /// provided by the symbol table along with erasing from the symbol table. 18 struct SymbolUsesPass 19 : public PassWrapper<SymbolUsesPass, OperationPass<ModuleOp>> { 20 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SymbolUsesPass) 21 22 StringRef getArgument() const final { return "test-symbol-uses"; } 23 StringRef getDescription() const final { 24 return "Test detection of symbol uses"; 25 } 26 WalkResult operateOnSymbol(Operation *symbol, ModuleOp module, 27 SmallVectorImpl<func::FuncOp> &deadFunctions) { 28 // Test computing uses on a non symboltable op. 29 Optional<SymbolTable::UseRange> symbolUses = 30 SymbolTable::getSymbolUses(symbol); 31 32 // Test the conservative failure case. 33 if (!symbolUses) { 34 symbol->emitRemark() 35 << "symbol contains an unknown nested operation that " 36 "'may' define a new symbol table"; 37 return WalkResult::interrupt(); 38 } 39 if (unsigned numUses = llvm::size(*symbolUses)) 40 symbol->emitRemark() << "symbol contains " << numUses 41 << " nested references"; 42 43 // Test the functionality of symbolKnownUseEmpty. 44 if (SymbolTable::symbolKnownUseEmpty(symbol, &module.getBodyRegion())) { 45 func::FuncOp funcSymbol = dyn_cast<func::FuncOp>(symbol); 46 if (funcSymbol && funcSymbol.isExternal()) 47 deadFunctions.push_back(funcSymbol); 48 49 symbol->emitRemark() << "symbol has no uses"; 50 return WalkResult::advance(); 51 } 52 53 // Test the functionality of getSymbolUses. 54 symbolUses = SymbolTable::getSymbolUses(symbol, &module.getBodyRegion()); 55 assert(symbolUses.hasValue() && "expected no unknown operations"); 56 for (SymbolTable::SymbolUse symbolUse : *symbolUses) { 57 // Check that we can resolve back to our symbol. 58 if (SymbolTable::lookupNearestSymbolFrom( 59 symbolUse.getUser()->getParentOp(), symbolUse.getSymbolRef())) { 60 symbolUse.getUser()->emitRemark() 61 << "found use of symbol : " << symbolUse.getSymbolRef() << " : " 62 << symbol->getAttr(SymbolTable::getSymbolAttrName()); 63 } 64 } 65 symbol->emitRemark() << "symbol has " << llvm::size(*symbolUses) << " uses"; 66 return WalkResult::advance(); 67 } 68 69 void runOnOperation() override { 70 auto module = getOperation(); 71 72 // Walk nested symbols. 73 SmallVector<func::FuncOp, 4> deadFunctions; 74 module.getBodyRegion().walk([&](Operation *nestedOp) { 75 if (isa<SymbolOpInterface>(nestedOp)) 76 return operateOnSymbol(nestedOp, module, deadFunctions); 77 return WalkResult::advance(); 78 }); 79 80 SymbolTable table(module); 81 for (Operation *op : deadFunctions) { 82 // In order to test the SymbolTable::erase method, also erase completely 83 // useless functions. 84 auto name = SymbolTable::getSymbolName(op); 85 assert(table.lookup(name) && "expected no unknown operations"); 86 table.erase(op); 87 assert(!table.lookup(name) && 88 "expected erased operation to be unknown now"); 89 module.emitRemark() << name.getValue() << " function successfully erased"; 90 } 91 } 92 }; 93 94 /// This is a symbol test pass that tests the symbol use replacement 95 /// functionality provided by the symbol table. 96 struct SymbolReplacementPass 97 : public PassWrapper<SymbolReplacementPass, OperationPass<ModuleOp>> { 98 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SymbolReplacementPass) 99 100 StringRef getArgument() const final { return "test-symbol-rauw"; } 101 StringRef getDescription() const final { 102 return "Test replacement of symbol uses"; 103 } 104 void runOnOperation() override { 105 ModuleOp module = getOperation(); 106 107 // Don't try to replace if we can't collect symbol uses. 108 if (!SymbolTable::getSymbolUses(&module.getBodyRegion())) 109 return; 110 111 SymbolTableCollection symbolTable; 112 SymbolUserMap symbolUsers(symbolTable, module); 113 module.getBodyRegion().walk([&](Operation *nestedOp) { 114 StringAttr newName = nestedOp->getAttrOfType<StringAttr>("sym.new_name"); 115 if (!newName) 116 return; 117 symbolUsers.replaceAllUsesWith(nestedOp, newName); 118 SymbolTable::setSymbolName(nestedOp, newName); 119 }); 120 } 121 }; 122 } // namespace 123 124 namespace mlir { 125 void registerSymbolTestPasses() { 126 PassRegistration<SymbolUsesPass>(); 127 128 PassRegistration<SymbolReplacementPass>(); 129 } 130 } // namespace mlir 131