1 //===- ModuleCombiner.cpp - MLIR SPIR-V Module Combiner ---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the SPIR-V module combiner library.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/ModuleCombiner.h"
14 
15 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/SymbolTable.h"
18 #include "llvm/ADT/ArrayRef.h"
19 #include "llvm/ADT/StringExtras.h"
20 
21 using namespace mlir;
22 
23 static constexpr unsigned maxFreeID = 1 << 20;
24 
25 static SmallString<64> renameSymbol(StringRef oldSymName, unsigned &lastUsedID,
26                                     spirv::ModuleOp combinedModule) {
27   SmallString<64> newSymName(oldSymName);
28   newSymName.push_back('_');
29 
30   while (lastUsedID < maxFreeID) {
31     std::string possible = (newSymName + llvm::utostr(++lastUsedID)).str();
32 
33     if (!SymbolTable::lookupSymbolIn(combinedModule, possible)) {
34       newSymName += llvm::utostr(lastUsedID);
35       break;
36     }
37   }
38 
39   return newSymName;
40 }
41 
42 /// Check if a symbol with the same name as op already exists in source. If so,
43 /// rename op and update all its references in target.
44 static LogicalResult updateSymbolAndAllUses(SymbolOpInterface op,
45                                             spirv::ModuleOp target,
46                                             spirv::ModuleOp source,
47                                             unsigned &lastUsedID) {
48   if (!SymbolTable::lookupSymbolIn(source, op.getName()))
49     return success();
50 
51   StringRef oldSymName = op.getName();
52   SmallString<64> newSymName = renameSymbol(oldSymName, lastUsedID, target);
53 
54   if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
55     return op.emitError("unable to update all symbol uses for ")
56            << oldSymName << " to " << newSymName;
57 
58   SymbolTable::setSymbolName(op, newSymName);
59   return success();
60 }
61 
62 namespace mlir {
63 namespace spirv {
64 
65 // TODO Properly test symbol rename listener mechanism.
66 
67 OwningSPIRVModuleRef
68 combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
69         OpBuilder &combinedModuleBuilder,
70         llvm::function_ref<void(ModuleOp, StringRef, StringRef)>
71             symRenameListener) {
72   unsigned lastUsedID = 0;
73 
74   if (modules.empty())
75     return nullptr;
76 
77   auto addressingModel = modules[0].addressing_model();
78   auto memoryModel = modules[0].memory_model();
79 
80   auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
81       modules[0].getLoc(), addressingModel, memoryModel);
82   combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody());
83 
84   // In some cases, a symbol in the (current state of the) combined module is
85   // renamed in order to maintain the conflicting symbol in the input module
86   // being merged. For example, if the conflict is between a global variable in
87   // the current combined module and a function in the input module, the global
88   // varaible is renamed. In order to notify listeners of the symbol updates in
89   // such cases, we need to keep track of the module from which the renamed
90   // symbol in the combined module originated. This map keeps such information.
91   DenseMap<StringRef, spirv::ModuleOp> symNameToModuleMap;
92 
93   for (auto module : modules) {
94     if (module.addressing_model() != addressingModel ||
95         module.memory_model() != memoryModel) {
96       module.emitError(
97           "input modules differ in addressing model and/or memory model");
98       return nullptr;
99     }
100 
101     spirv::ModuleOp moduleClone = module.clone();
102 
103     // In the combined module, rename all symbols that conflict with symbols
104     // from the current input module. This renmaing applies to all ops except
105     // for spv.funcs. This way, if the conflicting op in the input module is
106     // non-spv.func, we rename that symbol instead and maintain the spv.func in
107     // the combined module name as it is.
108     for (auto &op : combinedModule.getBlock().without_terminator()) {
109       auto symbolOp = dyn_cast<SymbolOpInterface>(op);
110       if (!symbolOp)
111         continue;
112 
113       StringRef oldSymName = symbolOp.getName();
114 
115       if (!isa<FuncOp>(op) &&
116           failed(updateSymbolAndAllUses(symbolOp, combinedModule, moduleClone,
117                                         lastUsedID)))
118         return nullptr;
119 
120       StringRef newSymName = symbolOp.getName();
121 
122       if (symRenameListener && oldSymName != newSymName) {
123         spirv::ModuleOp originalModule = symNameToModuleMap.lookup(oldSymName);
124 
125         if (!originalModule) {
126           module.emitError("unable to find original ModuleOp for symbol ")
127               << oldSymName;
128           return nullptr;
129         }
130 
131         symRenameListener(originalModule, oldSymName, newSymName);
132 
133         // Since the symbol name is updated, there is no need to maintain the
134         // entry that assocaites the old symbol name with the original module.
135         symNameToModuleMap.erase(oldSymName);
136         // Instead, add a new entry to map the new symbol name to the original
137         // module in case it gets renamed again later.
138         symNameToModuleMap[newSymName] = originalModule;
139       }
140     }
141 
142     // In the current input module, rename all symbols that conflict with
143     // symbols from the combined module. This includes renaming spv.funcs.
144     for (auto &op : moduleClone.getBlock().without_terminator()) {
145       auto symbolOp = dyn_cast<SymbolOpInterface>(op);
146       if (!symbolOp)
147         continue;
148 
149       StringRef oldSymName = symbolOp.getName();
150 
151       if (failed(updateSymbolAndAllUses(symbolOp, moduleClone, combinedModule,
152                                         lastUsedID)))
153         return nullptr;
154 
155       StringRef newSymName = symbolOp.getName();
156 
157       if (symRenameListener && oldSymName != newSymName) {
158         symRenameListener(module, oldSymName, newSymName);
159 
160         // Insert the module associated with the symbol name.
161         auto emplaceResult = symNameToModuleMap.try_emplace(newSymName, module);
162 
163         // If an entry with the same symbol name is already present, this must
164         // be a problem with the implementation, specially clean-up of the map
165         // while iterating over the combined module above.
166         if (!emplaceResult.second) {
167           module.emitError("did not expect to find an entry for symbol ")
168               << newSymName;
169           return nullptr;
170         }
171       }
172     }
173 
174     // Clone all the module's ops to the combined module.
175     for (auto &op : moduleClone.getBlock().without_terminator())
176       combinedModuleBuilder.insert(op.clone());
177   }
178 
179   return combinedModule;
180 }
181 
182 } // namespace spirv
183 } // namespace mlir
184