1 //===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements the execution engine for MLIR modules based on LLVM Orc 19 // JIT engine. 20 // 21 //===----------------------------------------------------------------------===// 22 #include "mlir/ExecutionEngine/ExecutionEngine.h" 23 #include "mlir/IR/Function.h" 24 #include "mlir/IR/Module.h" 25 #include "mlir/Target/LLVMIR.h" 26 27 #include "llvm/ExecutionEngine/Orc/CompileUtils.h" 28 #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" 29 #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" 30 #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" 31 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" 32 #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" 33 #include "llvm/ExecutionEngine/SectionMemoryManager.h" 34 #include "llvm/IR/IRBuilder.h" 35 #include "llvm/Support/Error.h" 36 #include "llvm/Support/TargetRegistry.h" 37 38 using namespace mlir; 39 using llvm::Error; 40 using llvm::Expected; 41 42 namespace { 43 // Memory manager for the JIT's objectLayer. Its main goal is to fallback to 44 // resolving functions in the current process if they cannot be resolved in the 45 // JIT-compiled modules. 46 class MemoryManager : public llvm::SectionMemoryManager { 47 public: 48 MemoryManager(llvm::orc::ExecutionSession &execSession) 49 : session(execSession) {} 50 51 // Resolve the named symbol. First, try looking it up in the main library of 52 // the execution session. If there is no such symbol, try looking it up in 53 // the current process (for example, if it is a standard library function). 54 // Return `nullptr` if lookup fails. 55 llvm::JITSymbol findSymbol(const std::string &name) override { 56 auto mainLibSymbol = session.lookup({&session.getMainJITDylib()}, name); 57 if (mainLibSymbol) 58 return mainLibSymbol.get(); 59 auto address = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name); 60 if (!address) { 61 llvm::errs() << "Could not look up: " << name << '\n'; 62 return nullptr; 63 } 64 return llvm::JITSymbol(address, llvm::JITSymbolFlags::Exported); 65 } 66 67 private: 68 llvm::orc::ExecutionSession &session; 69 }; 70 } // end anonymous namespace 71 72 namespace mlir { 73 namespace impl { 74 75 /// Wrapper class around DynamicLibrarySearchGenerator to allow searching 76 /// in-process symbols that have not been explicitly exported. 77 /// This first tries to resolve a symbol by using DynamicLibrarySearchGenerator. 78 /// For symbols that are not found this way, it then uses 79 /// `llvm::sys::DynamicLibrary::SearchForAddressOfSymbol` to extract symbols 80 /// that have been explicitly added with `llvm::sys::DynamicLibrary::AddSymbol`, 81 /// previously. 82 class SearchGenerator { 83 public: 84 SearchGenerator(char GlobalPrefix) 85 : defaultGenerator(cantFail( 86 llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( 87 GlobalPrefix))) {} 88 89 // This function forwards to DynamicLibrarySearchGenerator::operator() and 90 // adds an extra resolution for names explicitly registered via 91 // `llvm::sys::DynamicLibrary::AddSymbol`. 92 Expected<llvm::orc::SymbolNameSet> 93 operator()(llvm::orc::JITDylib &JD, const llvm::orc::SymbolNameSet &Names) { 94 auto res = defaultGenerator->tryToGenerate(JD, Names); 95 if (!res) 96 return res; 97 llvm::orc::SymbolMap newSymbols; 98 for (auto &Name : Names) { 99 if (res.get().count(Name) > 0) 100 continue; 101 res.get().insert(Name); 102 auto addedSymbolAddress = 103 llvm::sys::DynamicLibrary::SearchForAddressOfSymbol(*Name); 104 if (!addedSymbolAddress) 105 continue; 106 llvm::JITEvaluatedSymbol Sym( 107 reinterpret_cast<uintptr_t>(addedSymbolAddress), 108 llvm::JITSymbolFlags::Exported); 109 newSymbols[Name] = Sym; 110 } 111 if (!newSymbols.empty()) 112 cantFail(JD.define(absoluteSymbols(std::move(newSymbols)))); 113 return res; 114 } 115 116 private: 117 std::unique_ptr<llvm::orc::DynamicLibrarySearchGenerator> defaultGenerator; 118 }; 119 120 // Simple layered Orc JIT compilation engine. 121 class OrcJIT { 122 public: 123 using IRTransformer = std::function<Error(llvm::Module *)>; 124 125 // Construct a JIT engine for the target host defined by `machineBuilder`, 126 // using the data layout provided as `dataLayout`. 127 // Setup the object layer to use our custom memory manager in order to 128 // resolve calls to library functions present in the process. 129 OrcJIT(llvm::orc::JITTargetMachineBuilder machineBuilder, 130 llvm::DataLayout layout, IRTransformer transform, 131 ArrayRef<StringRef> sharedLibPaths) 132 : irTransformer(transform), 133 objectLayer( 134 session, 135 [this]() { return std::make_unique<MemoryManager>(session); }), 136 compileLayer( 137 session, objectLayer, 138 llvm::orc::ConcurrentIRCompiler(std::move(machineBuilder))), 139 transformLayer(session, compileLayer, makeIRTransformFunction()), 140 dataLayout(layout), mangler(session, this->dataLayout), 141 threadSafeCtx(std::make_unique<llvm::LLVMContext>()) { 142 session.getMainJITDylib().addGenerator( 143 cantFail(llvm::orc::DynamicLibrarySearchGenerator::GetForCurrentProcess( 144 layout.getGlobalPrefix()))); 145 loadLibraries(sharedLibPaths); 146 } 147 148 // Create a JIT engine for the current host. 149 static Expected<std::unique_ptr<OrcJIT>> 150 createDefault(IRTransformer transformer, ArrayRef<StringRef> sharedLibPaths) { 151 auto machineBuilder = llvm::orc::JITTargetMachineBuilder::detectHost(); 152 if (!machineBuilder) 153 return machineBuilder.takeError(); 154 155 auto dataLayout = machineBuilder->getDefaultDataLayoutForTarget(); 156 if (!dataLayout) 157 return dataLayout.takeError(); 158 159 return std::make_unique<OrcJIT>(std::move(*machineBuilder), 160 std::move(*dataLayout), transformer, 161 sharedLibPaths); 162 } 163 164 // Add an LLVM module to the main library managed by the JIT engine. 165 Error addModule(std::unique_ptr<llvm::Module> M) { 166 return transformLayer.add( 167 session.getMainJITDylib(), 168 llvm::orc::ThreadSafeModule(std::move(M), threadSafeCtx)); 169 } 170 171 // Lookup a symbol in the main library managed by the JIT engine. 172 Expected<llvm::JITEvaluatedSymbol> lookup(StringRef Name) { 173 return session.lookup({&session.getMainJITDylib()}, mangler(Name.str())); 174 } 175 176 private: 177 // Wrap the `irTransformer` into a function that can be called by the 178 // IRTranformLayer. If `irTransformer` is not set up, return the module as 179 // is without errors. 180 llvm::orc::IRTransformLayer::TransformFunction makeIRTransformFunction() { 181 return [this](llvm::orc::ThreadSafeModule module, 182 const llvm::orc::MaterializationResponsibility &resp) 183 -> Expected<llvm::orc::ThreadSafeModule> { 184 (void)resp; 185 if (!irTransformer) 186 return std::move(module); 187 Error err = module.withModuleDo( 188 [this](llvm::Module &module) { return irTransformer(&module); }); 189 if (err) 190 return std::move(err); 191 return std::move(module); 192 }; 193 } 194 195 // Iterate over shareLibPaths and load the corresponding libraries for symbol 196 // resolution. 197 void loadLibraries(ArrayRef<StringRef> sharedLibPaths); 198 199 IRTransformer irTransformer; 200 llvm::orc::ExecutionSession session; 201 llvm::orc::RTDyldObjectLinkingLayer objectLayer; 202 llvm::orc::IRCompileLayer compileLayer; 203 llvm::orc::IRTransformLayer transformLayer; 204 llvm::DataLayout dataLayout; 205 llvm::orc::MangleAndInterner mangler; 206 llvm::orc::ThreadSafeContext threadSafeCtx; 207 }; 208 } // end namespace impl 209 } // namespace mlir 210 211 void mlir::impl::OrcJIT::loadLibraries(ArrayRef<StringRef> sharedLibPaths) { 212 for (auto libPath : sharedLibPaths) { 213 auto mb = llvm::MemoryBuffer::getFile(libPath); 214 if (!mb) { 215 llvm::errs() << "Could not create MemoryBuffer for: " << libPath << " " 216 << mb.getError().message() << "\n"; 217 continue; 218 } 219 auto &JD = session.createJITDylib(libPath); 220 auto loaded = llvm::orc::DynamicLibrarySearchGenerator::Load( 221 libPath.data(), dataLayout.getGlobalPrefix()); 222 if (!loaded) { 223 llvm::errs() << "Could not load: " << libPath << " " << loaded.takeError() 224 << "\n"; 225 continue; 226 } 227 JD.addGenerator(std::move(*loaded)); 228 auto res = objectLayer.add(JD, std::move(mb.get())); 229 if (res) 230 llvm::errs() << "Could not add: " << libPath << " " << res << "\n"; 231 } 232 } 233 234 // Wrap a string into an llvm::StringError. 235 static inline Error make_string_error(const llvm::Twine &message) { 236 return llvm::make_error<llvm::StringError>(message.str(), 237 llvm::inconvertibleErrorCode()); 238 } 239 240 // Setup LLVM target triple from the current machine. 241 bool ExecutionEngine::setupTargetTriple(llvm::Module *llvmModule) { 242 // Setup the machine properties from the current architecture. 243 auto targetTriple = llvm::sys::getDefaultTargetTriple(); 244 std::string errorMessage; 245 auto target = llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage); 246 if (!target) { 247 llvm::errs() << "NO target: " << errorMessage << "\n"; 248 return true; 249 } 250 auto machine = 251 target->createTargetMachine(targetTriple, "generic", "", {}, {}); 252 llvmModule->setDataLayout(machine->createDataLayout()); 253 llvmModule->setTargetTriple(targetTriple); 254 return false; 255 } 256 257 static std::string makePackedFunctionName(StringRef name) { 258 return "_mlir_" + name.str(); 259 } 260 261 // For each function in the LLVM module, define an interface function that wraps 262 // all the arguments of the original function and all its results into an i8** 263 // pointer to provide a unified invocation interface. 264 void packFunctionArguments(llvm::Module *module) { 265 auto &ctx = module->getContext(); 266 llvm::IRBuilder<> builder(ctx); 267 llvm::DenseSet<llvm::Function *> interfaceFunctions; 268 for (auto &func : module->getFunctionList()) { 269 if (func.isDeclaration()) { 270 continue; 271 } 272 if (interfaceFunctions.count(&func)) { 273 continue; 274 } 275 276 // Given a function `foo(<...>)`, define the interface function 277 // `mlir_foo(i8**)`. 278 auto newType = llvm::FunctionType::get( 279 builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(), 280 /*isVarArg=*/false); 281 auto newName = makePackedFunctionName(func.getName()); 282 auto funcCst = module->getOrInsertFunction(newName, newType); 283 llvm::Function *interfaceFunc = 284 llvm::cast<llvm::Function>(funcCst.getCallee()); 285 interfaceFunctions.insert(interfaceFunc); 286 287 // Extract the arguments from the type-erased argument list and cast them to 288 // the proper types. 289 auto bb = llvm::BasicBlock::Create(ctx); 290 bb->insertInto(interfaceFunc); 291 builder.SetInsertPoint(bb); 292 llvm::Value *argList = interfaceFunc->arg_begin(); 293 llvm::SmallVector<llvm::Value *, 8> args; 294 args.reserve(llvm::size(func.args())); 295 for (auto &indexedArg : llvm::enumerate(func.args())) { 296 llvm::Value *argIndex = llvm::Constant::getIntegerValue( 297 builder.getInt64Ty(), llvm::APInt(64, indexedArg.index())); 298 llvm::Value *argPtrPtr = builder.CreateGEP(argList, argIndex); 299 llvm::Value *argPtr = builder.CreateLoad(argPtrPtr); 300 argPtr = builder.CreateBitCast( 301 argPtr, indexedArg.value().getType()->getPointerTo()); 302 llvm::Value *arg = builder.CreateLoad(argPtr); 303 args.push_back(arg); 304 } 305 306 // Call the implementation function with the extracted arguments. 307 llvm::Value *result = builder.CreateCall(&func, args); 308 309 // Assuming the result is one value, potentially of type `void`. 310 if (!result->getType()->isVoidTy()) { 311 llvm::Value *retIndex = llvm::Constant::getIntegerValue( 312 builder.getInt64Ty(), llvm::APInt(64, llvm::size(func.args()))); 313 llvm::Value *retPtrPtr = builder.CreateGEP(argList, retIndex); 314 llvm::Value *retPtr = builder.CreateLoad(retPtrPtr); 315 retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo()); 316 builder.CreateStore(result, retPtr); 317 } 318 319 // The interface function returns void. 320 builder.CreateRetVoid(); 321 } 322 } 323 324 // Out of line for PIMPL unique_ptr. 325 ExecutionEngine::~ExecutionEngine() = default; 326 327 Expected<std::unique_ptr<ExecutionEngine>> 328 ExecutionEngine::create(ModuleOp m, 329 std::function<llvm::Error(llvm::Module *)> transformer, 330 ArrayRef<StringRef> sharedLibPaths) { 331 auto engine = std::make_unique<ExecutionEngine>(); 332 auto expectedJIT = impl::OrcJIT::createDefault(transformer, sharedLibPaths); 333 if (!expectedJIT) 334 return expectedJIT.takeError(); 335 336 auto llvmModule = translateModuleToLLVMIR(m); 337 if (!llvmModule) 338 return make_string_error("could not convert to LLVM IR"); 339 // FIXME: the triple should be passed to the translation or dialect conversion 340 // instead of this. Currently, the LLVM module created above has no triple 341 // associated with it. 342 setupTargetTriple(llvmModule.get()); 343 packFunctionArguments(llvmModule.get()); 344 345 if (auto err = (*expectedJIT)->addModule(std::move(llvmModule))) 346 return std::move(err); 347 engine->jit = std::move(*expectedJIT); 348 349 return std::move(engine); 350 } 351 352 Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const { 353 auto expectedSymbol = jit->lookup(makePackedFunctionName(name)); 354 if (!expectedSymbol) 355 return expectedSymbol.takeError(); 356 auto rawFPtr = expectedSymbol->getAddress(); 357 auto fptr = reinterpret_cast<void (*)(void **)>(rawFPtr); 358 if (!fptr) 359 return make_string_error("looked up function is null"); 360 return fptr; 361 } 362 363 llvm::Error ExecutionEngine::invoke(StringRef name, 364 MutableArrayRef<void *> args) { 365 auto expectedFPtr = lookup(name); 366 if (!expectedFPtr) 367 return expectedFPtr.takeError(); 368 auto fptr = *expectedFPtr; 369 370 (*fptr)(args.data()); 371 372 return llvm::Error::success(); 373 } 374