1 //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the the SPIR-V module combiner library. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/ModuleCombiner.h" 14 15 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/SymbolTable.h" 18 #include "llvm/ADT/ArrayRef.h" 19 #include "llvm/ADT/StringExtras.h" 20 21 using namespace mlir; 22 23 static constexpr unsigned maxFreeID = 1 << 20; 24 25 static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID, 26 spirv::ModuleOp combinedModule) { 27 SmallString<64> newSymName(oldSymName); 28 newSymName.push_back('_'); 29 30 while (lastUsedID < maxFreeID) { 31 std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str(); 32 33 if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) { 34 newSymName += llvm::utostr(lastUsedID); 35 break; 36 } 37 } 38 39 return newSymName; 40 } 41 42 /// Check if a symbol with the same name as op already exists in source. If so, 43 /// rename op and update all its references in target. 44 static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, 45 spirv::ModuleOp target, 46 spirv::ModuleOp source, 47 unsigned &lastUsedID) { 48 if (!SymbolTable::lookupSymbolIn(source, op.getName())) 49 return success(); 50 51 StringRef oldSymName = op.getName(); 52 SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target); 53 54 if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target))) 55 return op.emitError("unable to update all symbol uses for ") 56 << oldSymName << " to " << newSymName; 57 58 SymbolTable::setSymbolName(op, newSymName); 59 return success(); 60 } 61 62 namespace mlir { 63 namespace spirv { 64 65 // TODO Properly test symbol rename listener mechanism. 66 67 OwningSPIRVModuleRef 68 combine(llvm::MutableArrayRef<spirv::ModuleOp> modules, 69 OpBuilder &combinedModuleBuilder, 70 llvm::function_ref<void(ModuleOp, StringRef, StringRef)> 71 symRenameListener) { 72 unsigned lastUsedID = 0; 73 74 if (modules.empty()) 75 return nullptr; 76 77 auto addressingModel = modules[0].addressing_model(); 78 auto memoryModel = modules[0].memory_model(); 79 80 auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>( 81 modules[0].getLoc(), addressingModel, memoryModel); 82 combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody()); 83 84 // In some cases, a symbol in the (current state of the) combined module is 85 // renamed in order to maintain the conflicting symbol in the input module 86 // being merged. For example, if the conflict is between a global variable in 87 // the current combined module and a function in the input module, the global 88 // varaible is renamed. In order to notify listeners of the symbol updates in 89 // such cases, we need to keep track of the module from which the renamed 90 // symbol in the combined module originated. This map keeps such information. 91 DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap; 92 93 for (auto module : modules) { 94 if (module.addressing_model() != addressingModel || 95 module.memory_model() != memoryModel) { 96 module.emitError( 97 "input modules differ in addressing model and/or memory model"); 98 return nullptr; 99 } 100 101 spirv::ModuleOp moduleClone = module.clone(); 102 103 // In the combined module, rename all symbols that conflict with symbols 104 // from the current input module. This renmaing applies to all ops except 105 // for spv.funcs. This way, if the conflicting op in the input module is 106 // non-spv.func, we rename that symbol instead and maintain the spv.func in 107 // the combined module name as it is. 108 for (auto &op : combinedModule.getBlock().without_terminator()) { 109 auto symbolOp = dyn_cast<SymbolOpInterface>(op); 110 if (!symbolOp) 111 continue; 112 113 StringRef oldSymName = symbolOp.getName(); 114 115 if (!isa<FuncOp>(op) && 116 failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone, 117 lastUsedID))) 118 return nullptr; 119 120 StringRef newSymName = symbolOp.getName(); 121 122 if (symRenameListener && oldSymName != newSymName) { 123 spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName); 124 125 if (!originalModule) { 126 module.emitError("unable to find original ModuleOp for symbol ") 127 << oldSymName; 128 return nullptr; 129 } 130 131 symRenameListener(originalModule, oldSymName, newSymName); 132 133 // Since the symbol name is updated, there is no need to maintain the 134 // entry that assocaites the old symbol name with the original module. 135 symNameToModuleMap.erase(oldSymName); 136 // Instead, add a new entry to map the new symbol name to the original 137 // module in case it gets renamed again later. 138 symNameToModuleMap[newSymName] = originalModule; 139 } 140 } 141 142 // In the current input module, rename all symbols that conflict with 143 // symbols from the combined module. This includes renaming spv.funcs. 144 for (auto &op : moduleClone.getBlock().without_terminator()) { 145 auto symbolOp = dyn_cast<SymbolOpInterface>(op); 146 if (!symbolOp) 147 continue; 148 149 StringRef oldSymName = symbolOp.getName(); 150 151 if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule, 152 lastUsedID))) 153 return nullptr; 154 155 StringRef newSymName = symbolOp.getName(); 156 157 if (symRenameListener && oldSymName != newSymName) { 158 symRenameListener(module, oldSymName, newSymName); 159 160 // Insert the module associated with the symbol name. 161 auto emplaceResult = symNameToModuleMap.try_emplace(newSymName, module); 162 163 // If an entry with the same symbol name is already present, this must 164 // be a problem with the implementation, specially clean-up of the map 165 // while iterating over the combined module above. 166 if (!emplaceResult.second) { 167 module.emitError("did not expect to find an entry for symbol ") 168 << newSymName; 169 return nullptr; 170 } 171 } 172 } 173 174 // Clone all the module's ops to the combined module. 175 for (auto &op : moduleClone.getBlock().without_terminator()) 176 combinedModuleBuilder.insert(op.clone()); 177 } 178 179 return combinedModule; 180 } 181 182 } // namespace spirv 183 } // namespace mlir 184