190a8260cSergawy //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===// 290a8260cSergawy // 390a8260cSergawy // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 490a8260cSergawy // See https://llvm.org/LICENSE.txt for license information. 590a8260cSergawy // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 690a8260cSergawy // 790a8260cSergawy //===----------------------------------------------------------------------===// 890a8260cSergawy // 990a8260cSergawy // This file implements the the SPIR-V module combiner library. 1090a8260cSergawy // 1190a8260cSergawy //===----------------------------------------------------------------------===// 1290a8260cSergawy 13*01178654SLei Zhang #include "mlir/Dialect/SPIRV/Linking/ModuleCombiner.h" 1490a8260cSergawy 15*01178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16*01178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 1790a8260cSergawy #include "mlir/IR/Builders.h" 1890a8260cSergawy #include "mlir/IR/SymbolTable.h" 1990a8260cSergawy #include "llvm/ADT/ArrayRef.h" 20341f3c11Sergawy #include "llvm/ADT/Hashing.h" 2190a8260cSergawy #include "llvm/ADT/StringExtras.h" 2290a8260cSergawy 2390a8260cSergawy using namespace mlir; 2490a8260cSergawy 2590a8260cSergawy static constexpr unsigned maxFreeID = 1 << 20; 2690a8260cSergawy 2790a8260cSergawy static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID, 2890a8260cSergawy spirv::ModuleOp combinedModule) { 2990a8260cSergawy SmallString<64> newSymName(oldSymName); 3090a8260cSergawy newSymName.push_back('_'); 3190a8260cSergawy 3290a8260cSergawy while (lastUsedID < maxFreeID) { 3390a8260cSergawy std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str(); 3490a8260cSergawy 3590a8260cSergawy if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) { 3690a8260cSergawy newSymName += llvm::utostr(lastUsedID); 3790a8260cSergawy break; 3890a8260cSergawy } 3990a8260cSergawy } 4090a8260cSergawy 4190a8260cSergawy return newSymName; 4290a8260cSergawy } 4390a8260cSergawy 4490a8260cSergawy /// Check if a symbol with the same name as op already exists in source. If so, 4590a8260cSergawy /// rename op and update all its references in target. 4690a8260cSergawy static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op, 4790a8260cSergawy spirv::ModuleOp target, 4890a8260cSergawy spirv::ModuleOp source, 4990a8260cSergawy unsigned &lastUsedID) { 5090a8260cSergawy if (!SymbolTable::lookupSymbolIn(source, op.getName())) 5190a8260cSergawy return success(); 5290a8260cSergawy 5390a8260cSergawy StringRef oldSymName = op.getName(); 5490a8260cSergawy SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target); 5590a8260cSergawy 5690a8260cSergawy if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target))) 5790a8260cSergawy return op.emitError("unable to update all symbol uses for ") 5890a8260cSergawy << oldSymName << " to " << newSymName; 5990a8260cSergawy 6090a8260cSergawy SymbolTable::setSymbolName(op, newSymName); 6190a8260cSergawy return success(); 6290a8260cSergawy } 6390a8260cSergawy 64341f3c11Sergawy template <typename KeyTy, typename SymbolOpTy> 65341f3c11Sergawy static SymbolOpTy 66341f3c11Sergawy emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp, 67341f3c11Sergawy DenseMap<KeyTy, SymbolOpTy> &deduplicationMap) { 68341f3c11Sergawy auto result = deduplicationMap.try_emplace(key, symbolOp); 69341f3c11Sergawy 70341f3c11Sergawy if (result.second) 71341f3c11Sergawy return SymbolOpTy(); 72341f3c11Sergawy 73341f3c11Sergawy return result.first->second; 74341f3c11Sergawy } 75341f3c11Sergawy 76341f3c11Sergawy /// Computes a hash code to represent the argument SymbolOpInterface based on 77341f3c11Sergawy /// all the Op's attributes except for the symbol name. 78341f3c11Sergawy /// 79341f3c11Sergawy /// \return the hash code computed from the Op's attributes as described above. 80341f3c11Sergawy /// 81341f3c11Sergawy /// Note: We use the operation's name (not the symbol name) as part of the hash 82341f3c11Sergawy /// computation. This prevents, for example, mistakenly considering a global 83341f3c11Sergawy /// variable and a spec constant as duplicates because their descriptor set + 84341f3c11Sergawy /// binding and spec_id, repectively, happen to hash to the same value. 85341f3c11Sergawy static llvm::hash_code computeHash(SymbolOpInterface symbolOp) { 86341f3c11Sergawy llvm::hash_code hashCode(0); 87c4a04059SChristian Sigg hashCode = llvm::hash_combine(symbolOp->getName()); 88341f3c11Sergawy 89c4a04059SChristian Sigg for (auto attr : symbolOp->getAttrs()) { 90341f3c11Sergawy if (attr.first == SymbolTable::getSymbolAttrName()) 91341f3c11Sergawy continue; 92341f3c11Sergawy hashCode = llvm::hash_combine(hashCode, attr); 93341f3c11Sergawy } 94341f3c11Sergawy 95341f3c11Sergawy return hashCode; 96341f3c11Sergawy } 97341f3c11Sergawy 98341f3c11Sergawy /// Computes a hash code from the argument Block. 99341f3c11Sergawy llvm::hash_code computeHash(Block *block) { 100341f3c11Sergawy // TODO: Consider extracting BlockEquivalenceData into a common header and 101341f3c11Sergawy // re-using it here. 102341f3c11Sergawy llvm::hash_code hash(0); 103341f3c11Sergawy 104341f3c11Sergawy for (Operation &op : *block) { 105341f3c11Sergawy // TODO: Properly handle operations with regions. 106341f3c11Sergawy if (op.getNumRegions() > 0) 107341f3c11Sergawy return 0; 108341f3c11Sergawy 109341f3c11Sergawy hash = llvm::hash_combine( 110341f3c11Sergawy hash, OperationEquivalence::computeHash( 111341f3c11Sergawy &op, OperationEquivalence::Flags::IgnoreOperands)); 112341f3c11Sergawy } 113341f3c11Sergawy 114341f3c11Sergawy return hash; 115341f3c11Sergawy } 116341f3c11Sergawy 11790a8260cSergawy namespace mlir { 11890a8260cSergawy namespace spirv { 11990a8260cSergawy 12090a8260cSergawy // TODO Properly test symbol rename listener mechanism. 12190a8260cSergawy 12290a8260cSergawy OwningSPIRVModuleRef 12390a8260cSergawy combine(llvm::MutableArrayRef<spirv::ModuleOp> modules, 12490a8260cSergawy OpBuilder &combinedModuleBuilder, 12590a8260cSergawy llvm::function_ref<void(ModuleOp, StringRef, StringRef)> 12690a8260cSergawy symRenameListener) { 12790a8260cSergawy unsigned lastUsedID = 0; 12890a8260cSergawy 12990a8260cSergawy if (modules.empty()) 13090a8260cSergawy return nullptr; 13190a8260cSergawy 13290a8260cSergawy auto addressingModel = modules[0].addressing_model(); 13390a8260cSergawy auto memoryModel = modules[0].memory_model(); 13490a8260cSergawy 13590a8260cSergawy auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>( 13690a8260cSergawy modules[0].getLoc(), addressingModel, memoryModel); 13790a8260cSergawy combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody()); 13890a8260cSergawy 13990a8260cSergawy // In some cases, a symbol in the (current state of the) combined module is 14090a8260cSergawy // renamed in order to maintain the conflicting symbol in the input module 14190a8260cSergawy // being merged. For example, if the conflict is between a global variable in 14290a8260cSergawy // the current combined module and a function in the input module, the global 14390a8260cSergawy // varaible is renamed. In order to notify listeners of the symbol updates in 14490a8260cSergawy // such cases, we need to keep track of the module from which the renamed 14590a8260cSergawy // symbol in the combined module originated. This map keeps such information. 14690a8260cSergawy DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap; 14790a8260cSergawy 14890a8260cSergawy for (auto module : modules) { 14990a8260cSergawy if (module.addressing_model() != addressingModel || 15090a8260cSergawy module.memory_model() != memoryModel) { 15190a8260cSergawy module.emitError( 15290a8260cSergawy "input modules differ in addressing model and/or memory model"); 15390a8260cSergawy return nullptr; 15490a8260cSergawy } 15590a8260cSergawy 15690a8260cSergawy spirv::ModuleOp moduleClone = module.clone(); 15790a8260cSergawy 15890a8260cSergawy // In the combined module, rename all symbols that conflict with symbols 15990a8260cSergawy // from the current input module. This renmaing applies to all ops except 16090a8260cSergawy // for spv.funcs. This way, if the conflicting op in the input module is 16190a8260cSergawy // non-spv.func, we rename that symbol instead and maintain the spv.func in 16290a8260cSergawy // the combined module name as it is. 16390a8260cSergawy for (auto &op : combinedModule.getBlock().without_terminator()) { 16490a8260cSergawy if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) { 16590a8260cSergawy StringRef oldSymName = symbolOp.getName(); 16690a8260cSergawy 16790a8260cSergawy if (!isa<FuncOp>(op) && 16890a8260cSergawy failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone, 16990a8260cSergawy lastUsedID))) 17090a8260cSergawy return nullptr; 17190a8260cSergawy 17290a8260cSergawy StringRef newSymName = symbolOp.getName(); 17390a8260cSergawy 17490a8260cSergawy if (symRenameListener && oldSymName != newSymName) { 17590a8260cSergawy spirv::ModuleOp originalModule = 17690a8260cSergawy symNameToModuleMap.lookup(oldSymName); 17790a8260cSergawy 17890a8260cSergawy if (!originalModule) { 17990a8260cSergawy module.emitError("unable to find original ModuleOp for symbol ") 18090a8260cSergawy << oldSymName; 18190a8260cSergawy return nullptr; 18290a8260cSergawy } 18390a8260cSergawy 18490a8260cSergawy symRenameListener(originalModule, oldSymName, newSymName); 18590a8260cSergawy 18690a8260cSergawy // Since the symbol name is updated, there is no need to maintain the 18790a8260cSergawy // entry that assocaites the old symbol name with the original module. 18890a8260cSergawy symNameToModuleMap.erase(oldSymName); 18990a8260cSergawy // Instead, add a new entry to map the new symbol name to the original 19090a8260cSergawy // module in case it gets renamed again later. 19190a8260cSergawy symNameToModuleMap[newSymName] = originalModule; 19290a8260cSergawy } 19390a8260cSergawy } 19490a8260cSergawy } 19590a8260cSergawy 19690a8260cSergawy // In the current input module, rename all symbols that conflict with 19790a8260cSergawy // symbols from the combined module. This includes renaming spv.funcs. 19890a8260cSergawy for (auto &op : moduleClone.getBlock().without_terminator()) { 19990a8260cSergawy if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) { 20090a8260cSergawy StringRef oldSymName = symbolOp.getName(); 20190a8260cSergawy 20290a8260cSergawy if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule, 20390a8260cSergawy lastUsedID))) 20490a8260cSergawy return nullptr; 20590a8260cSergawy 20690a8260cSergawy StringRef newSymName = symbolOp.getName(); 20790a8260cSergawy 20890a8260cSergawy if (symRenameListener && oldSymName != newSymName) { 20990a8260cSergawy symRenameListener(module, oldSymName, newSymName); 21090a8260cSergawy 21190a8260cSergawy // Insert the module associated with the symbol name. 21290a8260cSergawy auto emplaceResult = 21390a8260cSergawy symNameToModuleMap.try_emplace(symbolOp.getName(), module); 21490a8260cSergawy 21590a8260cSergawy // If an entry with the same symbol name is already present, this must 21690a8260cSergawy // be a problem with the implementation, specially clean-up of the map 21790a8260cSergawy // while iterating over the combined module above. 21890a8260cSergawy if (!emplaceResult.second) { 21990a8260cSergawy module.emitError("did not expect to find an entry for symbol ") 22090a8260cSergawy << symbolOp.getName(); 22190a8260cSergawy return nullptr; 22290a8260cSergawy } 22390a8260cSergawy } 22490a8260cSergawy } 22590a8260cSergawy } 22690a8260cSergawy 22790a8260cSergawy // Clone all the module's ops to the combined module. 22890a8260cSergawy for (auto &op : moduleClone.getBlock().without_terminator()) 22990a8260cSergawy combinedModuleBuilder.insert(op.clone()); 23090a8260cSergawy } 23190a8260cSergawy 232341f3c11Sergawy // Deduplicate identical global variables, spec constants, and functions. 233341f3c11Sergawy DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp; 234341f3c11Sergawy SmallVector<SymbolOpInterface, 0> eraseList; 235341f3c11Sergawy 236341f3c11Sergawy for (auto &op : combinedModule.getBlock().without_terminator()) { 237341f3c11Sergawy llvm::hash_code hashCode(0); 238341f3c11Sergawy SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op); 239341f3c11Sergawy 240341f3c11Sergawy if (!symbolOp) 241341f3c11Sergawy continue; 242341f3c11Sergawy 243341f3c11Sergawy hashCode = computeHash(symbolOp); 244341f3c11Sergawy 245341f3c11Sergawy // A 0 hash code means the op is not suitable for deduplication and should 246341f3c11Sergawy // be skipped. An example of this is when a function has ops with regions 247341f3c11Sergawy // which are not properly supported yet. 248341f3c11Sergawy if (!hashCode) 249341f3c11Sergawy continue; 250341f3c11Sergawy 251341f3c11Sergawy if (auto funcOp = dyn_cast<FuncOp>(op)) 252341f3c11Sergawy for (auto &blk : funcOp) 253341f3c11Sergawy hashCode = llvm::hash_combine(hashCode, computeHash(&blk)); 254341f3c11Sergawy 255341f3c11Sergawy SymbolOpInterface replacementSymOp = 256341f3c11Sergawy emplaceOrGetReplacementSymbol(hashCode, symbolOp, hashToSymbolOp); 257341f3c11Sergawy 258341f3c11Sergawy if (!replacementSymOp) 259341f3c11Sergawy continue; 260341f3c11Sergawy 261341f3c11Sergawy if (failed(SymbolTable::replaceAllSymbolUses( 262341f3c11Sergawy symbolOp, replacementSymOp.getName(), combinedModule))) { 263341f3c11Sergawy symbolOp.emitError("unable to update all symbol uses for ") 264341f3c11Sergawy << symbolOp.getName() << " to " << replacementSymOp.getName(); 265341f3c11Sergawy return nullptr; 266341f3c11Sergawy } 267341f3c11Sergawy 268341f3c11Sergawy eraseList.push_back(symbolOp); 269341f3c11Sergawy } 270341f3c11Sergawy 271341f3c11Sergawy for (auto symbolOp : eraseList) 272341f3c11Sergawy symbolOp.erase(); 273341f3c11Sergawy 27490a8260cSergawy return combinedModule; 27590a8260cSergawy } 27690a8260cSergawy 27790a8260cSergawy } // namespace spirv 27890a8260cSergawy } // namespace mlir 279