1 //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the the SPIR-V module combiner library. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Dialect/SPIRV/ModuleCombiner.h" 14 15 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/SymbolTable.h" 18 #include "llvm/ADT/ArrayRef.h" 19 #include "llvm/ADT/StringExtras.h" 20 21 using namespace mlir; 22 23 static constexpr unsigned maxFreeID = 1 << 20; 24 25 static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID, 26 spirv::ModuleOp combinedModule) { 27 SmallString<64> newSymName(oldSymName); 28 newSymName.push_back('_'); 29 30 while (lastUsedID < maxFreeID) { 31 std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str(); 32 33 if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) { 34 newSymName += llvm::utostr(lastUsedID); 35 break; 36 } 37 } 38 39 return newSymName; 40 } 41 42 /// Check if a symbol with the same name as op already exists in source. If so, 43 /// rename op and update all its references in target. 44 static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, 45 spirv::ModuleOp target, 46 spirv::ModuleOp source, 47 unsigned &lastUsedID) { 48 if (!SymbolTable::lookupSymbolIn(source, op.getName())) 49 return success(); 50 51 StringRef oldSymName = op.getName(); 52 SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target); 53 54 if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target))) 55 return op.emitError("unable to update all symbol uses for ") 56 << oldSymName << " to " << newSymName; 57 58 SymbolTable::setSymbolName(op, newSymName); 59 return success(); 60 } 61 62 namespace mlir { 63 namespace spirv { 64 65 // TODO Properly test symbol rename listener mechanism. 66 67 OwningSPIRVModuleRef 68 combine(llvm::MutableArrayRef<spirv::ModuleOp> modules, 69 OpBuilder &combinedModuleBuilder, 70 llvm::function_ref<void(ModuleOp, StringRef, StringRef)> 71 symRenameListener) { 72 unsigned lastUsedID = 0; 73 74 if (modules.empty()) 75 return nullptr; 76 77 auto addressingModel = modules[0].addressing_model(); 78 auto memoryModel = modules[0].memory_model(); 79 80 auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>( 81 modules[0].getLoc(), addressingModel, memoryModel); 82 combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody()); 83 84 // In some cases, a symbol in the (current state of the) combined module is 85 // renamed in order to maintain the conflicting symbol in the input module 86 // being merged. For example, if the conflict is between a global variable in 87 // the current combined module and a function in the input module, the global 88 // varaible is renamed. In order to notify listeners of the symbol updates in 89 // such cases, we need to keep track of the module from which the renamed 90 // symbol in the combined module originated. This map keeps such information. 91 DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap; 92 93 for (auto module : modules) { 94 if (module.addressing_model() != addressingModel || 95 module.memory_model() != memoryModel) { 96 module.emitError( 97 "input modules differ in addressing model and/or memory model"); 98 return nullptr; 99 } 100 101 spirv::ModuleOp moduleClone = module.clone(); 102 103 // In the combined module, rename all symbols that conflict with symbols 104 // from the current input module. This renmaing applies to all ops except 105 // for spv.funcs. This way, if the conflicting op in the input module is 106 // non-spv.func, we rename that symbol instead and maintain the spv.func in 107 // the combined module name as it is. 108 for (auto &op : combinedModule.getBlock().without_terminator()) { 109 if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) { 110 StringRef oldSymName = symbolOp.getName(); 111 112 if (!isa<FuncOp>(op) && 113 failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone, 114 lastUsedID))) 115 return nullptr; 116 117 StringRef newSymName = symbolOp.getName(); 118 119 if (symRenameListener && oldSymName != newSymName) { 120 spirv::ModuleOp originalModule = 121 symNameToModuleMap.lookup(oldSymName); 122 123 if (!originalModule) { 124 module.emitError("unable to find original ModuleOp for symbol ") 125 << oldSymName; 126 return nullptr; 127 } 128 129 symRenameListener(originalModule, oldSymName, newSymName); 130 131 // Since the symbol name is updated, there is no need to maintain the 132 // entry that assocaites the old symbol name with the original module. 133 symNameToModuleMap.erase(oldSymName); 134 // Instead, add a new entry to map the new symbol name to the original 135 // module in case it gets renamed again later. 136 symNameToModuleMap[newSymName] = originalModule; 137 } 138 } 139 } 140 141 // In the current input module, rename all symbols that conflict with 142 // symbols from the combined module. This includes renaming spv.funcs. 143 for (auto &op : moduleClone.getBlock().without_terminator()) { 144 if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) { 145 StringRef oldSymName = symbolOp.getName(); 146 147 if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule, 148 lastUsedID))) 149 return nullptr; 150 151 StringRef newSymName = symbolOp.getName(); 152 153 if (symRenameListener && oldSymName != newSymName) { 154 symRenameListener(module, oldSymName, newSymName); 155 156 // Insert the module associated with the symbol name. 157 auto emplaceResult = 158 symNameToModuleMap.try_emplace(symbolOp.getName(), module); 159 160 // If an entry with the same symbol name is already present, this must 161 // be a problem with the implementation, specially clean-up of the map 162 // while iterating over the combined module above. 163 if (!emplaceResult.second) { 164 module.emitError("did not expect to find an entry for symbol ") 165 << symbolOp.getName(); 166 return nullptr; 167 } 168 } 169 } 170 } 171 172 // Clone all the module's ops to the combined module. 173 for (auto &op : moduleClone.getBlock().without_terminator()) 174 combinedModuleBuilder.insert(op.clone()); 175 } 176 177 return combinedModule; 178 } 179 180 } // namespace spirv 181 } // namespace mlir 182