190a8260cSergawy //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
290a8260cSergawy //
390a8260cSergawy // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
490a8260cSergawy // See https://llvm.org/LICENSE.txt for license information.
590a8260cSergawy // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
690a8260cSergawy //
790a8260cSergawy //===----------------------------------------------------------------------===//
890a8260cSergawy //
990a8260cSergawy // This file implements the SPIR-V module combiner library.
1090a8260cSergawy //
1190a8260cSergawy //===----------------------------------------------------------------------===//
1290a8260cSergawy 
1390a8260cSergawy #include "mlir/Dialect/SPIRV/ModuleCombiner.h"
1490a8260cSergawy 
15*341f3c11Sergawy #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
1690a8260cSergawy #include "mlir/Dialect/SPIRV/SPIRVOps.h"
1790a8260cSergawy #include "mlir/IR/Builders.h"
1890a8260cSergawy #include "mlir/IR/SymbolTable.h"
1990a8260cSergawy #include "llvm/ADT/ArrayRef.h"
20*341f3c11Sergawy #include "llvm/ADT/Hashing.h"
2190a8260cSergawy #include "llvm/ADT/StringExtras.h"
2290a8260cSergawy 
2390a8260cSergawy using namespace mlir;
2490a8260cSergawy 
2590a8260cSergawy static constexpr unsigned maxFreeID = 1 << 20;
2690a8260cSergawy 
2790a8260cSergawy static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
2890a8260cSergawy                                     spirv::ModuleOp combinedModule) {
2990a8260cSergawy   SmallString<64> newSymName(oldSymName);
3090a8260cSergawy   newSymName.push_back('_');
3190a8260cSergawy 
3290a8260cSergawy   while (lastUsedID < maxFreeID) {
3390a8260cSergawy     std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();
3490a8260cSergawy 
3590a8260cSergawy     if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) {
3690a8260cSergawy       newSymName += llvm::utostr(lastUsedID);
3790a8260cSergawy       break;
3890a8260cSergawy     }
3990a8260cSergawy   }
4090a8260cSergawy 
4190a8260cSergawy   return newSymName;
4290a8260cSergawy }
4390a8260cSergawy 
4490a8260cSergawy /// Check if a symbol with the same name as op already exists in source. If so,
4590a8260cSergawy /// rename op and update all its references in target.
4690a8260cSergawy static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
4790a8260cSergawy                                             spirv::ModuleOp target,
4890a8260cSergawy                                             spirv::ModuleOp source,
4990a8260cSergawy                                             unsigned &lastUsedID) {
5090a8260cSergawy   if (!SymbolTable::lookupSymbolIn(source, op.getName()))
5190a8260cSergawy     return success();
5290a8260cSergawy 
5390a8260cSergawy   StringRef oldSymName = op.getName();
5490a8260cSergawy   SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target);
5590a8260cSergawy 
5690a8260cSergawy   if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
5790a8260cSergawy     return op.emitError("unable to update all symbol uses for ")
5890a8260cSergawy            << oldSymName << " to " << newSymName;
5990a8260cSergawy 
6090a8260cSergawy   SymbolTable::setSymbolName(op, newSymName);
6190a8260cSergawy   return success();
6290a8260cSergawy }
6390a8260cSergawy 
64*341f3c11Sergawy template <typename KeyTy, typename SymbolOpTy>
65*341f3c11Sergawy static SymbolOpTy
66*341f3c11Sergawy emplaceOrGetReplacementSymbol(KeyTy key, SymbolOpTy symbolOp,
67*341f3c11Sergawy                               DenseMap<KeyTy, SymbolOpTy> &deduplicationMap) {
68*341f3c11Sergawy   auto result = deduplicationMap.try_emplace(key, symbolOp);
69*341f3c11Sergawy 
70*341f3c11Sergawy   if (result.second)
71*341f3c11Sergawy     return SymbolOpTy();
72*341f3c11Sergawy 
73*341f3c11Sergawy   return result.first->second;
74*341f3c11Sergawy }
75*341f3c11Sergawy 
76*341f3c11Sergawy /// Computes a hash code to represent the argument SymbolOpInterface based on
77*341f3c11Sergawy /// all the Op's attributes except for the symbol name.
78*341f3c11Sergawy ///
79*341f3c11Sergawy /// \return the hash code computed from the Op's attributes as described above.
80*341f3c11Sergawy ///
81*341f3c11Sergawy /// Note: We use the operation's name (not the symbol name) as part of the hash
82*341f3c11Sergawy /// computation. This prevents, for example, mistakenly considering a global
83*341f3c11Sergawy /// variable and a spec constant as duplicates because their descriptor set +
84*341f3c11Sergawy /// binding and spec_id, repectively, happen to hash to the same value.
85*341f3c11Sergawy static llvm::hash_code computeHash(SymbolOpInterface symbolOp) {
86*341f3c11Sergawy   llvm::hash_code hashCode(0);
87*341f3c11Sergawy   hashCode = llvm::hash_combine(symbolOp.getOperation()->getName());
88*341f3c11Sergawy 
89*341f3c11Sergawy   for (auto attr : symbolOp.getOperation()->getAttrs()) {
90*341f3c11Sergawy     if (attr.first == SymbolTable::getSymbolAttrName())
91*341f3c11Sergawy       continue;
92*341f3c11Sergawy     hashCode = llvm::hash_combine(hashCode, attr);
93*341f3c11Sergawy   }
94*341f3c11Sergawy 
95*341f3c11Sergawy   return hashCode;
96*341f3c11Sergawy }
97*341f3c11Sergawy 
98*341f3c11Sergawy /// Computes a hash code from the argument Block.
99*341f3c11Sergawy llvm::hash_code computeHash(Block *block) {
100*341f3c11Sergawy   // TODO: Consider extracting BlockEquivalenceData into a common header and
101*341f3c11Sergawy   // re-using it here.
102*341f3c11Sergawy   llvm::hash_code hash(0);
103*341f3c11Sergawy 
104*341f3c11Sergawy   for (Operation &op : *block) {
105*341f3c11Sergawy     // TODO: Properly handle operations with regions.
106*341f3c11Sergawy     if (op.getNumRegions() > 0)
107*341f3c11Sergawy       return 0;
108*341f3c11Sergawy 
109*341f3c11Sergawy     hash = llvm::hash_combine(
110*341f3c11Sergawy         hash, OperationEquivalence::computeHash(
111*341f3c11Sergawy                   &op, OperationEquivalence::Flags::IgnoreOperands));
112*341f3c11Sergawy   }
113*341f3c11Sergawy 
114*341f3c11Sergawy   return hash;
115*341f3c11Sergawy }
116*341f3c11Sergawy 
11790a8260cSergawy namespace mlir {
11890a8260cSergawy namespace spirv {
11990a8260cSergawy 
12090a8260cSergawy // TODO Properly test symbol rename listener mechanism.
12190a8260cSergawy 
12290a8260cSergawy OwningSPIRVModuleRef
12390a8260cSergawy combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
12490a8260cSergawy         OpBuilder &combinedModuleBuilder,
12590a8260cSergawy         llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
12690a8260cSergawy             symRenameListener) {
12790a8260cSergawy   unsigned lastUsedID = 0;
12890a8260cSergawy 
12990a8260cSergawy   if (modules.empty())
13090a8260cSergawy     return nullptr;
13190a8260cSergawy 
13290a8260cSergawy   auto addressingModel = modules[0].addressing_model();
13390a8260cSergawy   auto memoryModel = modules[0].memory_model();
13490a8260cSergawy 
13590a8260cSergawy   auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
13690a8260cSergawy       modules[0].getLoc(), addressingModel, memoryModel);
13790a8260cSergawy   combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody());
13890a8260cSergawy 
13990a8260cSergawy   // In some cases, a symbol in the (current state of the) combined module is
14090a8260cSergawy   // renamed in order to maintain the conflicting symbol in the input module
14190a8260cSergawy   // being merged. For example, if the conflict is between a global variable in
14290a8260cSergawy   // the current combined module and a function in the input module, the global
14390a8260cSergawy   // varaible is renamed. In order to notify listeners of the symbol updates in
14490a8260cSergawy   // such cases, we need to keep track of the module from which the renamed
14590a8260cSergawy   // symbol in the combined module originated. This map keeps such information.
14690a8260cSergawy   DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap;
14790a8260cSergawy 
14890a8260cSergawy   for (auto module : modules) {
14990a8260cSergawy     if (module.addressing_model() != addressingModel ||
15090a8260cSergawy         module.memory_model() != memoryModel) {
15190a8260cSergawy       module.emitError(
15290a8260cSergawy           "input modules differ in addressing model and/or memory model");
15390a8260cSergawy       return nullptr;
15490a8260cSergawy     }
15590a8260cSergawy 
15690a8260cSergawy     spirv::ModuleOp moduleClone = module.clone();
15790a8260cSergawy 
15890a8260cSergawy     // In the combined module, rename all symbols that conflict with symbols
15990a8260cSergawy     // from the current input module. This renmaing applies to all ops except
16090a8260cSergawy     // for spv.funcs. This way, if the conflicting op in the input module is
16190a8260cSergawy     // non-spv.func, we rename that symbol instead and maintain the spv.func in
16290a8260cSergawy     // the combined module name as it is.
16390a8260cSergawy     for (auto &op : combinedModule.getBlock().without_terminator()) {
16490a8260cSergawy       if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
16590a8260cSergawy         StringRef oldSymName = symbolOp.getName();
16690a8260cSergawy 
16790a8260cSergawy         if (!isa<FuncOp>(op) &&
16890a8260cSergawy             failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
16990a8260cSergawy                                           lastUsedID)))
17090a8260cSergawy           return nullptr;
17190a8260cSergawy 
17290a8260cSergawy         StringRef newSymName = symbolOp.getName();
17390a8260cSergawy 
17490a8260cSergawy         if (symRenameListener && oldSymName != newSymName) {
17590a8260cSergawy           spirv::ModuleOp originalModule =
17690a8260cSergawy               symNameToModuleMap.lookup(oldSymName);
17790a8260cSergawy 
17890a8260cSergawy           if (!originalModule) {
17990a8260cSergawy             module.emitError("unable to find original ModuleOp for symbol ")
18090a8260cSergawy                 << oldSymName;
18190a8260cSergawy             return nullptr;
18290a8260cSergawy           }
18390a8260cSergawy 
18490a8260cSergawy           symRenameListener(originalModule, oldSymName, newSymName);
18590a8260cSergawy 
18690a8260cSergawy           // Since the symbol name is updated, there is no need to maintain the
18790a8260cSergawy           // entry that assocaites the old symbol name with the original module.
18890a8260cSergawy           symNameToModuleMap.erase(oldSymName);
18990a8260cSergawy           // Instead, add a new entry to map the new symbol name to the original
19090a8260cSergawy           // module in case it gets renamed again later.
19190a8260cSergawy           symNameToModuleMap[newSymName] = originalModule;
19290a8260cSergawy         }
19390a8260cSergawy       }
19490a8260cSergawy     }
19590a8260cSergawy 
19690a8260cSergawy     // In the current input module, rename all symbols that conflict with
19790a8260cSergawy     // symbols from the combined module. This includes renaming spv.funcs.
19890a8260cSergawy     for (auto &op : moduleClone.getBlock().without_terminator()) {
19990a8260cSergawy       if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
20090a8260cSergawy         StringRef oldSymName = symbolOp.getName();
20190a8260cSergawy 
20290a8260cSergawy         if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
20390a8260cSergawy                                           lastUsedID)))
20490a8260cSergawy           return nullptr;
20590a8260cSergawy 
20690a8260cSergawy         StringRef newSymName = symbolOp.getName();
20790a8260cSergawy 
20890a8260cSergawy         if (symRenameListener && oldSymName != newSymName) {
20990a8260cSergawy           symRenameListener(module, oldSymName, newSymName);
21090a8260cSergawy 
21190a8260cSergawy           // Insert the module associated with the symbol name.
21290a8260cSergawy           auto emplaceResult =
21390a8260cSergawy               symNameToModuleMap.try_emplace(symbolOp.getName(), module);
21490a8260cSergawy 
21590a8260cSergawy           // If an entry with the same symbol name is already present, this must
21690a8260cSergawy           // be a problem with the implementation, specially clean-up of the map
21790a8260cSergawy           // while iterating over the combined module above.
21890a8260cSergawy           if (!emplaceResult.second) {
21990a8260cSergawy             module.emitError("did not expect to find an entry for symbol ")
22090a8260cSergawy                 << symbolOp.getName();
22190a8260cSergawy             return nullptr;
22290a8260cSergawy           }
22390a8260cSergawy         }
22490a8260cSergawy       }
22590a8260cSergawy     }
22690a8260cSergawy 
22790a8260cSergawy     // Clone all the module's ops to the combined module.
22890a8260cSergawy     for (auto &op : moduleClone.getBlock().without_terminator())
22990a8260cSergawy       combinedModuleBuilder.insert(op.clone());
23090a8260cSergawy   }
23190a8260cSergawy 
232*341f3c11Sergawy   // Deduplicate identical global variables, spec constants, and functions.
233*341f3c11Sergawy   DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
234*341f3c11Sergawy   SmallVector<SymbolOpInterface, 0> eraseList;
235*341f3c11Sergawy 
236*341f3c11Sergawy   for (auto &op : combinedModule.getBlock().without_terminator()) {
237*341f3c11Sergawy     llvm::hash_code hashCode(0);
238*341f3c11Sergawy     SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
239*341f3c11Sergawy 
240*341f3c11Sergawy     if (!symbolOp)
241*341f3c11Sergawy       continue;
242*341f3c11Sergawy 
243*341f3c11Sergawy     hashCode = computeHash(symbolOp);
244*341f3c11Sergawy 
245*341f3c11Sergawy     // A 0 hash code means the op is not suitable for deduplication and should
246*341f3c11Sergawy     // be skipped. An example of this is when a function has ops with regions
247*341f3c11Sergawy     // which are not properly supported yet.
248*341f3c11Sergawy     if (!hashCode)
249*341f3c11Sergawy       continue;
250*341f3c11Sergawy 
251*341f3c11Sergawy     if (auto funcOp = dyn_cast<FuncOp>(op))
252*341f3c11Sergawy       for (auto &blk : funcOp)
253*341f3c11Sergawy         hashCode = llvm::hash_combine(hashCode, computeHash(&blk));
254*341f3c11Sergawy 
255*341f3c11Sergawy     SymbolOpInterface replacementSymOp =
256*341f3c11Sergawy         emplaceOrGetReplacementSymbol(hashCode, symbolOp, hashToSymbolOp);
257*341f3c11Sergawy 
258*341f3c11Sergawy     if (!replacementSymOp)
259*341f3c11Sergawy       continue;
260*341f3c11Sergawy 
261*341f3c11Sergawy     if (failed(SymbolTable::replaceAllSymbolUses(
262*341f3c11Sergawy             symbolOp, replacementSymOp.getName(), combinedModule))) {
263*341f3c11Sergawy       symbolOp.emitError("unable to update all symbol uses for ")
264*341f3c11Sergawy           << symbolOp.getName() << " to " << replacementSymOp.getName();
265*341f3c11Sergawy       return nullptr;
266*341f3c11Sergawy     }
267*341f3c11Sergawy 
268*341f3c11Sergawy     eraseList.push_back(symbolOp);
269*341f3c11Sergawy   }
270*341f3c11Sergawy 
271*341f3c11Sergawy   for (auto symbolOp : eraseList)
272*341f3c11Sergawy     symbolOp.erase();
273*341f3c11Sergawy 
27490a8260cSergawy   return combinedModule;
27590a8260cSergawy }
27690a8260cSergawy 
27790a8260cSergawy } // namespace spirv
27890a8260cSergawy } // namespace mlir
279