190a8260cSergawy //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===// 290a8260cSergawy // 390a8260cSergawy // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 490a8260cSergawy // See https://llvm.org/LICENSE.txt for license information. 590a8260cSergawy // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 690a8260cSergawy // 790a8260cSergawy //===----------------------------------------------------------------------===// 890a8260cSergawy // 9f88fab50SKazuaki Ishizaki // This file implements the SPIR-V module combiner library. 1090a8260cSergawy // 1190a8260cSergawy //===----------------------------------------------------------------------===// 1290a8260cSergawy 1301178654SLei Zhang #include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h" 1490a8260cSergawy 15*23326b9fSLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 1601178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 1701178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 18*23326b9fSLei Zhang #include "mlir/IR/Attributes.h" 1990a8260cSergawy #include "mlir/IR/Builders.h" 2090a8260cSergawy #include "mlir/IR/SymbolTable.h" 2190a8260cSergawy #include "llvm/ADT/ArrayRef.h" 22341f3c11Sergawy #include "llvm/ADT/Hashing.h" 23*23326b9fSLei Zhang #include "llvm/ADT/STLExtras.h" 2490a8260cSergawy #include "llvm/ADT/StringExtras.h" 25*23326b9fSLei Zhang #include "llvm/ADT/StringMap.h" 2690a8260cSergawy 2790a8260cSergawy using namespace mlir; 2890a8260cSergawy 2990a8260cSergawy static constexpr unsigned maxFreeID = 1 << 20; 3090a8260cSergawy 31*23326b9fSLei Zhang /// Returns an unsed symbol in `module` for `oldSymbolName` by trying numeric 32*23326b9fSLei Zhang /// suffix in `lastUsedID`. 3390a8260cSergawy static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID, 34*23326b9fSLei Zhang spirv::ModuleOp module) { 3590a8260cSergawy SmallString<64> newSymName(oldSymName); 3690a8260cSergawy newSymName.push_back('_'); 3790a8260cSergawy 3890a8260cSergawy while (lastUsedID < maxFreeID) { 3990a8260cSergawy std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str(); 4090a8260cSergawy 41*23326b9fSLei Zhang if (!SymbolTable::lookupSymbolIn(module, possible)) { 4290a8260cSergawy newSymName += llvm::utostr(lastUsedID); 4390a8260cSergawy break; 4490a8260cSergawy } 4590a8260cSergawy } 4690a8260cSergawy 4790a8260cSergawy return newSymName; 4890a8260cSergawy } 4990a8260cSergawy 50*23326b9fSLei Zhang /// Checks if a symbol with the same name as `op` already exists in `source`. 51*23326b9fSLei Zhang /// If so, renames `op` and updates all its references in `target`. 5290a8260cSergawy static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, 5390a8260cSergawy spirv::ModuleOp target, 5490a8260cSergawy spirv::ModuleOp source, 5590a8260cSergawy unsigned &lastUsedID) { 5690a8260cSergawy if (!SymbolTable::lookupSymbolIn(source, op.getName())) 5790a8260cSergawy return success(); 5890a8260cSergawy 5990a8260cSergawy StringRef oldSymName = op.getName(); 6090a8260cSergawy SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target); 6190a8260cSergawy 6290a8260cSergawy if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target))) 6390a8260cSergawy return op.emitError("unable to update all symbol uses for ") 6490a8260cSergawy << oldSymName << " to " << newSymName; 6590a8260cSergawy 6690a8260cSergawy SymbolTable::setSymbolName(op, newSymName); 6790a8260cSergawy return success(); 6890a8260cSergawy } 6990a8260cSergawy 70*23326b9fSLei Zhang /// Computes a hash code to represent `symbolOp` based on all its attributes 71*23326b9fSLei Zhang /// except for the symbol name. 72341f3c11Sergawy /// 73341f3c11Sergawy /// Note: We use the operation's name (not the symbol name) as part of the hash 74341f3c11Sergawy /// computation. This prevents, for example, mistakenly considering a global 75341f3c11Sergawy /// variable and a spec constant as duplicates because their descriptor set + 76f88fab50SKazuaki Ishizaki /// binding and spec_id, respectively, happen to hash to the same value. 77341f3c11Sergawy static llvm::hash_code computeHash(SymbolOpInterface symbolOp) { 78*23326b9fSLei Zhang auto range = 79*23326b9fSLei Zhang llvm::make_filter_range(symbolOp->getAttrs(), [](NamedAttribute attr) { 80*23326b9fSLei Zhang return attr.first != SymbolTable::getSymbolAttrName(); 81*23326b9fSLei Zhang }); 82341f3c11Sergawy 83*23326b9fSLei Zhang return llvm::hash_combine( 84*23326b9fSLei Zhang symbolOp->getName(), 85*23326b9fSLei Zhang llvm::hash_combine_range(range.begin(), range.end())); 86341f3c11Sergawy } 87341f3c11Sergawy 8890a8260cSergawy namespace mlir { 8990a8260cSergawy namespace spirv { 9090a8260cSergawy 91*23326b9fSLei Zhang OwningOpRef<spirv::ModuleOp> combine(ArrayRef<spirv::ModuleOp> inputModules, 9290a8260cSergawy OpBuilder &combinedModuleBuilder, 93*23326b9fSLei Zhang SymbolRenameListener symRenameListener) { 94*23326b9fSLei Zhang if (inputModules.empty()) 9590a8260cSergawy return nullptr; 9690a8260cSergawy 97*23326b9fSLei Zhang spirv::ModuleOp firstModule = inputModules.front(); 98*23326b9fSLei Zhang auto addressingModel = firstModule.addressing_model(); 99*23326b9fSLei Zhang auto memoryModel = firstModule.memory_model(); 100*23326b9fSLei Zhang auto vceTriple = firstModule.vce_triple(); 101*23326b9fSLei Zhang 102*23326b9fSLei Zhang // First check whether there are conflicts between addressing/memory model. 103*23326b9fSLei Zhang // Return early if so. 104*23326b9fSLei Zhang for (auto module : inputModules) { 105*23326b9fSLei Zhang if (module.addressing_model() != addressingModel || 106*23326b9fSLei Zhang module.memory_model() != memoryModel || 107*23326b9fSLei Zhang module.vce_triple() != vceTriple) { 108*23326b9fSLei Zhang module.emitError("input modules differ in addressing model, memory " 109*23326b9fSLei Zhang "model, and/or VCE triple"); 110*23326b9fSLei Zhang return nullptr; 111*23326b9fSLei Zhang } 112*23326b9fSLei Zhang } 11390a8260cSergawy 11490a8260cSergawy auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>( 115*23326b9fSLei Zhang firstModule.getLoc(), addressingModel, memoryModel, vceTriple); 11656f60a1cSLei Zhang combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody()); 11790a8260cSergawy 11890a8260cSergawy // In some cases, a symbol in the (current state of the) combined module is 119*23326b9fSLei Zhang // renamed in order to enable the conflicting symbol in the input module 12090a8260cSergawy // being merged. For example, if the conflict is between a global variable in 12190a8260cSergawy // the current combined module and a function in the input module, the global 122f88fab50SKazuaki Ishizaki // variable is renamed. In order to notify listeners of the symbol updates in 12390a8260cSergawy // such cases, we need to keep track of the module from which the renamed 12490a8260cSergawy // symbol in the combined module originated. This map keeps such information. 125*23326b9fSLei Zhang llvm::StringMap<spirv::ModuleOp> symNameToModuleMap; 12690a8260cSergawy 127*23326b9fSLei Zhang unsigned lastUsedID = 0; 12890a8260cSergawy 129*23326b9fSLei Zhang for (auto inputModule : inputModules) { 130*23326b9fSLei Zhang spirv::ModuleOp moduleClone = inputModule.clone(); 13190a8260cSergawy 13290a8260cSergawy // In the combined module, rename all symbols that conflict with symbols 133f88fab50SKazuaki Ishizaki // from the current input module. This renaming applies to all ops except 13490a8260cSergawy // for spv.funcs. This way, if the conflicting op in the input module is 13590a8260cSergawy // non-spv.func, we rename that symbol instead and maintain the spv.func in 13690a8260cSergawy // the combined module name as it is. 13756f60a1cSLei Zhang for (auto &op : *combinedModule.getBody()) { 138*23326b9fSLei Zhang auto symbolOp = dyn_cast<SymbolOpInterface>(op); 139*23326b9fSLei Zhang if (!symbolOp) 140*23326b9fSLei Zhang continue; 141*23326b9fSLei Zhang 14290a8260cSergawy StringRef oldSymName = symbolOp.getName(); 14390a8260cSergawy 14490a8260cSergawy if (!isa<FuncOp>(op) && 14590a8260cSergawy failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone, 14690a8260cSergawy lastUsedID))) 14790a8260cSergawy return nullptr; 14890a8260cSergawy 14990a8260cSergawy StringRef newSymName = symbolOp.getName(); 15090a8260cSergawy 15190a8260cSergawy if (symRenameListener && oldSymName != newSymName) { 152*23326b9fSLei Zhang spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName); 15390a8260cSergawy 15490a8260cSergawy if (!originalModule) { 155*23326b9fSLei Zhang inputModule.emitError( 156*23326b9fSLei Zhang "unable to find original spirv::ModuleOp for symbol ") 15790a8260cSergawy << oldSymName; 15890a8260cSergawy return nullptr; 15990a8260cSergawy } 16090a8260cSergawy 16190a8260cSergawy symRenameListener(originalModule, oldSymName, newSymName); 16290a8260cSergawy 16390a8260cSergawy // Since the symbol name is updated, there is no need to maintain the 164f88fab50SKazuaki Ishizaki // entry that associates the old symbol name with the original module. 16590a8260cSergawy symNameToModuleMap.erase(oldSymName); 16690a8260cSergawy // Instead, add a new entry to map the new symbol name to the original 16790a8260cSergawy // module in case it gets renamed again later. 16890a8260cSergawy symNameToModuleMap[newSymName] = originalModule; 16990a8260cSergawy } 17090a8260cSergawy } 17190a8260cSergawy 17290a8260cSergawy // In the current input module, rename all symbols that conflict with 17390a8260cSergawy // symbols from the combined module. This includes renaming spv.funcs. 17456f60a1cSLei Zhang for (auto &op : *moduleClone.getBody()) { 175*23326b9fSLei Zhang auto symbolOp = dyn_cast<SymbolOpInterface>(op); 176*23326b9fSLei Zhang if (!symbolOp) 177*23326b9fSLei Zhang continue; 178*23326b9fSLei Zhang 17990a8260cSergawy StringRef oldSymName = symbolOp.getName(); 18090a8260cSergawy 18190a8260cSergawy if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule, 18290a8260cSergawy lastUsedID))) 18390a8260cSergawy return nullptr; 18490a8260cSergawy 18590a8260cSergawy StringRef newSymName = symbolOp.getName(); 18690a8260cSergawy 187*23326b9fSLei Zhang if (symRenameListener) { 188*23326b9fSLei Zhang if (oldSymName != newSymName) 189*23326b9fSLei Zhang symRenameListener(inputModule, oldSymName, newSymName); 19090a8260cSergawy 19190a8260cSergawy // Insert the module associated with the symbol name. 19290a8260cSergawy auto emplaceResult = 193*23326b9fSLei Zhang symNameToModuleMap.try_emplace(newSymName, inputModule); 19490a8260cSergawy 19590a8260cSergawy // If an entry with the same symbol name is already present, this must 19690a8260cSergawy // be a problem with the implementation, specially clean-up of the map 19790a8260cSergawy // while iterating over the combined module above. 19890a8260cSergawy if (!emplaceResult.second) { 199*23326b9fSLei Zhang inputModule.emitError("did not expect to find an entry for symbol ") 20090a8260cSergawy << symbolOp.getName(); 20190a8260cSergawy return nullptr; 20290a8260cSergawy } 20390a8260cSergawy } 20490a8260cSergawy } 20590a8260cSergawy 20690a8260cSergawy // Clone all the module's ops to the combined module. 20756f60a1cSLei Zhang for (auto &op : *moduleClone.getBody()) 20890a8260cSergawy combinedModuleBuilder.insert(op.clone()); 20990a8260cSergawy } 21090a8260cSergawy 211341f3c11Sergawy // Deduplicate identical global variables, spec constants, and functions. 212341f3c11Sergawy DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp; 213341f3c11Sergawy SmallVector<SymbolOpInterface, 0> eraseList; 214341f3c11Sergawy 21556f60a1cSLei Zhang for (auto &op : *combinedModule.getBody()) { 216341f3c11Sergawy SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op); 217341f3c11Sergawy if (!symbolOp) 218341f3c11Sergawy continue; 219341f3c11Sergawy 220*23326b9fSLei Zhang // Do not support ops with operands or results. 221*23326b9fSLei Zhang // Global variables, spec constants, and functions won't have 222*23326b9fSLei Zhang // operands/results, but just for safety here. 223*23326b9fSLei Zhang if (op.getNumOperands() != 0 || op.getNumResults() != 0) 224341f3c11Sergawy continue; 225341f3c11Sergawy 226*23326b9fSLei Zhang // Deduplicating functions are not supported yet. 227*23326b9fSLei Zhang if (isa<FuncOp>(op)) 228341f3c11Sergawy continue; 229341f3c11Sergawy 230*23326b9fSLei Zhang auto result = hashToSymbolOp.try_emplace(computeHash(symbolOp), symbolOp); 231*23326b9fSLei Zhang if (result.second) 232*23326b9fSLei Zhang continue; 233*23326b9fSLei Zhang 234*23326b9fSLei Zhang SymbolOpInterface replacementSymOp = result.first->second; 235*23326b9fSLei Zhang 236341f3c11Sergawy if (failed(SymbolTable::replaceAllSymbolUses( 237341f3c11Sergawy symbolOp, replacementSymOp.getName(), combinedModule))) { 238341f3c11Sergawy symbolOp.emitError("unable to update all symbol uses for ") 239341f3c11Sergawy << symbolOp.getName() << " to " << replacementSymOp.getName(); 240341f3c11Sergawy return nullptr; 241341f3c11Sergawy } 242341f3c11Sergawy 243341f3c11Sergawy eraseList.push_back(symbolOp); 244341f3c11Sergawy } 245341f3c11Sergawy 246341f3c11Sergawy for (auto symbolOp : eraseList) 247341f3c11Sergawy symbolOp.erase(); 248341f3c11Sergawy 24990a8260cSergawy return combinedModule; 25090a8260cSergawy } 25190a8260cSergawy 25290a8260cSergawy } // namespace spirv 25390a8260cSergawy } // namespace mlir 254