15a440378SAlex Zinenko //===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===// 25a440378SAlex Zinenko // 35a440378SAlex Zinenko // Copyright 2019 The MLIR Authors. 45a440378SAlex Zinenko // 55a440378SAlex Zinenko // Licensed under the Apache License, Version 2.0 (the "License"); 65a440378SAlex Zinenko // you may not use this file except in compliance with the License. 75a440378SAlex Zinenko // You may obtain a copy of the License at 85a440378SAlex Zinenko // 95a440378SAlex Zinenko // http://www.apache.org/licenses/LICENSE-2.0 105a440378SAlex Zinenko // 115a440378SAlex Zinenko // Unless required by applicable law or agreed to in writing, software 125a440378SAlex Zinenko // distributed under the License is distributed on an "AS IS" BASIS, 135a440378SAlex Zinenko // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 145a440378SAlex Zinenko // See the License for the specific language governing permissions and 155a440378SAlex Zinenko // limitations under the License. 165a440378SAlex Zinenko // ============================================================================= 175a440378SAlex Zinenko // 185a440378SAlex Zinenko // This file implements the execution engine for MLIR modules based on LLVM Orc 195a440378SAlex Zinenko // JIT engine. 205a440378SAlex Zinenko // 215a440378SAlex Zinenko //===----------------------------------------------------------------------===// 225a440378SAlex Zinenko #include "mlir/ExecutionEngine/ExecutionEngine.h" 235a440378SAlex Zinenko #include "mlir/IR/Function.h" 245a440378SAlex Zinenko #include "mlir/IR/Module.h" 2506e81010SJacques Pienaar #include "mlir/Support/FileUtilities.h" 265a440378SAlex Zinenko #include "mlir/Target/LLVMIR.h" 275a440378SAlex Zinenko 28fe3594f7SNicolas Vasilache #include "llvm/Bitcode/BitcodeReader.h" 29fe3594f7SNicolas Vasilache #include "llvm/Bitcode/BitcodeWriter.h" 30fe3594f7SNicolas Vasilache #include "llvm/ExecutionEngine/ObjectCache.h" 315a440378SAlex Zinenko #include "llvm/ExecutionEngine/Orc/CompileUtils.h" 325a440378SAlex Zinenko #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" 335a440378SAlex Zinenko #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" 348093f17aSAlex Zinenko #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" 355a440378SAlex Zinenko #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" 365a440378SAlex Zinenko #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" 375a440378SAlex Zinenko #include "llvm/ExecutionEngine/SectionMemoryManager.h" 385a440378SAlex Zinenko #include "llvm/IR/IRBuilder.h" 395a440378SAlex Zinenko #include "llvm/Support/Error.h" 405a440378SAlex Zinenko #include "llvm/Support/TargetRegistry.h" 4106e81010SJacques Pienaar #include "llvm/Support/ToolOutputFile.h" 425a440378SAlex Zinenko 435a440378SAlex Zinenko using namespace mlir; 44fe3594f7SNicolas Vasilache using llvm::dbgs; 455a440378SAlex Zinenko using llvm::Error; 46fe3594f7SNicolas Vasilache using llvm::errs; 475a440378SAlex Zinenko using llvm::Expected; 48fe3594f7SNicolas Vasilache using llvm::LLVMContext; 49fe3594f7SNicolas Vasilache using llvm::MemoryBuffer; 50fe3594f7SNicolas Vasilache using llvm::MemoryBufferRef; 51fe3594f7SNicolas Vasilache using llvm::Module; 52fe3594f7SNicolas Vasilache using llvm::SectionMemoryManager; 53fe3594f7SNicolas Vasilache using llvm::StringError; 54fe3594f7SNicolas Vasilache using llvm::Triple; 55fe3594f7SNicolas Vasilache using llvm::orc::DynamicLibrarySearchGenerator; 56fe3594f7SNicolas Vasilache using llvm::orc::ExecutionSession; 57fe3594f7SNicolas Vasilache using llvm::orc::IRCompileLayer; 58fe3594f7SNicolas Vasilache using llvm::orc::JITTargetMachineBuilder; 59fe3594f7SNicolas Vasilache using llvm::orc::RTDyldObjectLinkingLayer; 60fe3594f7SNicolas Vasilache using llvm::orc::ThreadSafeModule; 61fe3594f7SNicolas Vasilache using llvm::orc::TMOwningSimpleCompiler; 626aa5cc8bSNicolas Vasilache 635a440378SAlex Zinenko // Wrap a string into an llvm::StringError. 645a440378SAlex Zinenko static inline Error make_string_error(const llvm::Twine &message) { 65fe3594f7SNicolas Vasilache return llvm::make_error<StringError>(message.str(), 665a440378SAlex Zinenko llvm::inconvertibleErrorCode()); 675a440378SAlex Zinenko } 685a440378SAlex Zinenko 69fe3594f7SNicolas Vasilache namespace mlir { 70fe3594f7SNicolas Vasilache 71fe3594f7SNicolas Vasilache void SimpleObjectCache::notifyObjectCompiled(const Module *M, 72fe3594f7SNicolas Vasilache MemoryBufferRef ObjBuffer) { 7306e81010SJacques Pienaar cachedObjects[M->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy( 74fe3594f7SNicolas Vasilache ObjBuffer.getBuffer(), ObjBuffer.getBufferIdentifier()); 75fe3594f7SNicolas Vasilache } 76fe3594f7SNicolas Vasilache 77fe3594f7SNicolas Vasilache std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *M) { 7806e81010SJacques Pienaar auto I = cachedObjects.find(M->getModuleIdentifier()); 7906e81010SJacques Pienaar if (I == cachedObjects.end()) { 80fe3594f7SNicolas Vasilache dbgs() << "No object for " << M->getModuleIdentifier() 81fe3594f7SNicolas Vasilache << " in cache. Compiling.\n"; 82fe3594f7SNicolas Vasilache return nullptr; 83fe3594f7SNicolas Vasilache } 84fe3594f7SNicolas Vasilache dbgs() << "Object for " << M->getModuleIdentifier() 85fe3594f7SNicolas Vasilache << " loaded from cache.\n"; 86fe3594f7SNicolas Vasilache return MemoryBuffer::getMemBuffer(I->second->getMemBufferRef()); 87fe3594f7SNicolas Vasilache } 88fe3594f7SNicolas Vasilache 8906e81010SJacques Pienaar void SimpleObjectCache::dumpToObjectFile(llvm::StringRef outputFilename) { 9006e81010SJacques Pienaar // Set up the output file. 9106e81010SJacques Pienaar std::string errorMessage; 9206e81010SJacques Pienaar auto file = openOutputFile(outputFilename, &errorMessage); 9306e81010SJacques Pienaar if (!file) { 9406e81010SJacques Pienaar llvm::errs() << errorMessage << "\n"; 9506e81010SJacques Pienaar return; 9606e81010SJacques Pienaar } 9706e81010SJacques Pienaar 9806e81010SJacques Pienaar // Dump the object generated for a single module to the output file. 9906e81010SJacques Pienaar assert(cachedObjects.size() == 1 && "Expected only one object entry."); 10006e81010SJacques Pienaar auto &cachedObject = cachedObjects.begin()->second; 10106e81010SJacques Pienaar file->os() << cachedObject->getBuffer(); 10206e81010SJacques Pienaar file->keep(); 10306e81010SJacques Pienaar } 10406e81010SJacques Pienaar 10506e81010SJacques Pienaar void ExecutionEngine::dumpToObjectFile(llvm::StringRef filename) { 10606e81010SJacques Pienaar cache->dumpToObjectFile(filename); 10706e81010SJacques Pienaar } 10806e81010SJacques Pienaar 1095a440378SAlex Zinenko // Setup LLVM target triple from the current machine. 110fe3594f7SNicolas Vasilache bool ExecutionEngine::setupTargetTriple(Module *llvmModule) { 1115a440378SAlex Zinenko // Setup the machine properties from the current architecture. 1125a440378SAlex Zinenko auto targetTriple = llvm::sys::getDefaultTargetTriple(); 1135a440378SAlex Zinenko std::string errorMessage; 1145a440378SAlex Zinenko auto target = llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage); 1155a440378SAlex Zinenko if (!target) { 116fe3594f7SNicolas Vasilache errs() << "NO target: " << errorMessage << "\n"; 1175a440378SAlex Zinenko return true; 1185a440378SAlex Zinenko } 1195a440378SAlex Zinenko auto machine = 1205a440378SAlex Zinenko target->createTargetMachine(targetTriple, "generic", "", {}, {}); 1215a440378SAlex Zinenko llvmModule->setDataLayout(machine->createDataLayout()); 1225a440378SAlex Zinenko llvmModule->setTargetTriple(targetTriple); 1235a440378SAlex Zinenko return false; 1245a440378SAlex Zinenko } 1255a440378SAlex Zinenko 1265a440378SAlex Zinenko static std::string makePackedFunctionName(StringRef name) { 1275a440378SAlex Zinenko return "_mlir_" + name.str(); 1285a440378SAlex Zinenko } 1295a440378SAlex Zinenko 1305a440378SAlex Zinenko // For each function in the LLVM module, define an interface function that wraps 1315a440378SAlex Zinenko // all the arguments of the original function and all its results into an i8** 1325a440378SAlex Zinenko // pointer to provide a unified invocation interface. 133fe3594f7SNicolas Vasilache void packFunctionArguments(Module *module) { 1345a440378SAlex Zinenko auto &ctx = module->getContext(); 1355a440378SAlex Zinenko llvm::IRBuilder<> builder(ctx); 1365a440378SAlex Zinenko llvm::DenseSet<llvm::Function *> interfaceFunctions; 1375a440378SAlex Zinenko for (auto &func : module->getFunctionList()) { 1385a440378SAlex Zinenko if (func.isDeclaration()) { 1395a440378SAlex Zinenko continue; 1405a440378SAlex Zinenko } 1415a440378SAlex Zinenko if (interfaceFunctions.count(&func)) { 1425a440378SAlex Zinenko continue; 1435a440378SAlex Zinenko } 1445a440378SAlex Zinenko 1455a440378SAlex Zinenko // Given a function `foo(<...>)`, define the interface function 1465a440378SAlex Zinenko // `mlir_foo(i8**)`. 1475a440378SAlex Zinenko auto newType = llvm::FunctionType::get( 1485a440378SAlex Zinenko builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(), 1495a440378SAlex Zinenko /*isVarArg=*/false); 1505a440378SAlex Zinenko auto newName = makePackedFunctionName(func.getName()); 151c46b0feaSRiver Riddle auto funcCst = module->getOrInsertFunction(newName, newType); 152c46b0feaSRiver Riddle llvm::Function *interfaceFunc = 153c46b0feaSRiver Riddle llvm::cast<llvm::Function>(funcCst.getCallee()); 1545a440378SAlex Zinenko interfaceFunctions.insert(interfaceFunc); 1555a440378SAlex Zinenko 1565a440378SAlex Zinenko // Extract the arguments from the type-erased argument list and cast them to 1575a440378SAlex Zinenko // the proper types. 1585a440378SAlex Zinenko auto bb = llvm::BasicBlock::Create(ctx); 1595a440378SAlex Zinenko bb->insertInto(interfaceFunc); 1605a440378SAlex Zinenko builder.SetInsertPoint(bb); 1615a440378SAlex Zinenko llvm::Value *argList = interfaceFunc->arg_begin(); 1625a440378SAlex Zinenko llvm::SmallVector<llvm::Value *, 8> args; 1635a440378SAlex Zinenko args.reserve(llvm::size(func.args())); 1645a440378SAlex Zinenko for (auto &indexedArg : llvm::enumerate(func.args())) { 1655a440378SAlex Zinenko llvm::Value *argIndex = llvm::Constant::getIntegerValue( 1665a440378SAlex Zinenko builder.getInt64Ty(), llvm::APInt(64, indexedArg.index())); 1675a440378SAlex Zinenko llvm::Value *argPtrPtr = builder.CreateGEP(argList, argIndex); 1685a440378SAlex Zinenko llvm::Value *argPtr = builder.CreateLoad(argPtrPtr); 1695a440378SAlex Zinenko argPtr = builder.CreateBitCast( 1705a440378SAlex Zinenko argPtr, indexedArg.value().getType()->getPointerTo()); 1715a440378SAlex Zinenko llvm::Value *arg = builder.CreateLoad(argPtr); 1725a440378SAlex Zinenko args.push_back(arg); 1735a440378SAlex Zinenko } 1745a440378SAlex Zinenko 1755a440378SAlex Zinenko // Call the implementation function with the extracted arguments. 1765a440378SAlex Zinenko llvm::Value *result = builder.CreateCall(&func, args); 1775a440378SAlex Zinenko 1785a440378SAlex Zinenko // Assuming the result is one value, potentially of type `void`. 1795a440378SAlex Zinenko if (!result->getType()->isVoidTy()) { 1805a440378SAlex Zinenko llvm::Value *retIndex = llvm::Constant::getIntegerValue( 1815a440378SAlex Zinenko builder.getInt64Ty(), llvm::APInt(64, llvm::size(func.args()))); 1825a440378SAlex Zinenko llvm::Value *retPtrPtr = builder.CreateGEP(argList, retIndex); 1835a440378SAlex Zinenko llvm::Value *retPtr = builder.CreateLoad(retPtrPtr); 1845a440378SAlex Zinenko retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo()); 1855a440378SAlex Zinenko builder.CreateStore(result, retPtr); 1865a440378SAlex Zinenko } 1875a440378SAlex Zinenko 1885a440378SAlex Zinenko // The interface function returns void. 1895a440378SAlex Zinenko builder.CreateRetVoid(); 1905a440378SAlex Zinenko } 1915a440378SAlex Zinenko } 1925a440378SAlex Zinenko 19306e81010SJacques Pienaar ExecutionEngine::ExecutionEngine(bool enableObjectCache) 19406e81010SJacques Pienaar : cache(enableObjectCache ? nullptr : new SimpleObjectCache()) {} 19506e81010SJacques Pienaar 19606e81010SJacques Pienaar Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create( 19706e81010SJacques Pienaar ModuleOp m, std::function<Error(llvm::Module *)> transformer, 19806e81010SJacques Pienaar ArrayRef<StringRef> sharedLibPaths, bool enableObjectCache) { 19906e81010SJacques Pienaar auto engine = std::make_unique<ExecutionEngine>(enableObjectCache); 2005a440378SAlex Zinenko 201fe3594f7SNicolas Vasilache std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext); 202206e55ccSRiver Riddle auto llvmModule = translateModuleToLLVMIR(m); 2035a440378SAlex Zinenko if (!llvmModule) 2045a440378SAlex Zinenko return make_string_error("could not convert to LLVM IR"); 2055a440378SAlex Zinenko // FIXME: the triple should be passed to the translation or dialect conversion 2065a440378SAlex Zinenko // instead of this. Currently, the LLVM module created above has no triple 2075a440378SAlex Zinenko // associated with it. 2085a440378SAlex Zinenko setupTargetTriple(llvmModule.get()); 2095a440378SAlex Zinenko packFunctionArguments(llvmModule.get()); 2105a440378SAlex Zinenko 211fe3594f7SNicolas Vasilache // Clone module in a new LLVMContext since translateModuleToLLVMIR buries 212fe3594f7SNicolas Vasilache // ownership too deeply. 213fe3594f7SNicolas Vasilache // TODO(zinenko): Reevaluate model of ownership of LLVMContext in LLVMDialect. 214fe3594f7SNicolas Vasilache SmallVector<char, 1> buffer; 215fe3594f7SNicolas Vasilache { 216fe3594f7SNicolas Vasilache llvm::raw_svector_ostream os(buffer); 217fe3594f7SNicolas Vasilache WriteBitcodeToFile(*llvmModule, os); 218fe3594f7SNicolas Vasilache } 219fe3594f7SNicolas Vasilache llvm::MemoryBufferRef bufferRef(llvm::StringRef(buffer.data(), buffer.size()), 220fe3594f7SNicolas Vasilache "cloned module buffer"); 221fe3594f7SNicolas Vasilache auto expectedModule = parseBitcodeFile(bufferRef, *ctx); 222fe3594f7SNicolas Vasilache if (!expectedModule) 223fe3594f7SNicolas Vasilache return expectedModule.takeError(); 224fe3594f7SNicolas Vasilache std::unique_ptr<Module> deserModule = std::move(*expectedModule); 225fe3594f7SNicolas Vasilache 226fe3594f7SNicolas Vasilache // Callback to create the object layer with symbol resolution to current 227fe3594f7SNicolas Vasilache // process and dynamically linked libraries. 228fe3594f7SNicolas Vasilache auto objectLinkingLayerCreator = [&](ExecutionSession &session, 229fe3594f7SNicolas Vasilache const Triple &TT) { 230fe3594f7SNicolas Vasilache auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>( 231fe3594f7SNicolas Vasilache session, []() { return std::make_unique<SectionMemoryManager>(); }); 232fe3594f7SNicolas Vasilache auto dataLayout = deserModule->getDataLayout(); 233fe3594f7SNicolas Vasilache 234fe3594f7SNicolas Vasilache // Resolve symbols that are statically linked in the current process. 235fe3594f7SNicolas Vasilache session.getMainJITDylib().addGenerator( 236fe3594f7SNicolas Vasilache cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess( 237fe3594f7SNicolas Vasilache dataLayout.getGlobalPrefix()))); 238fe3594f7SNicolas Vasilache 239fe3594f7SNicolas Vasilache // Resolve symbols from shared libraries. 240fe3594f7SNicolas Vasilache for (auto libPath : sharedLibPaths) { 241fe3594f7SNicolas Vasilache auto mb = llvm::MemoryBuffer::getFile(libPath); 242fe3594f7SNicolas Vasilache if (!mb) { 243fe3594f7SNicolas Vasilache errs() << "Fail to create MemoryBuffer for: " << libPath << "\n"; 244fe3594f7SNicolas Vasilache continue; 245fe3594f7SNicolas Vasilache } 246fe3594f7SNicolas Vasilache auto &JD = session.createJITDylib(libPath); 247fe3594f7SNicolas Vasilache auto loaded = DynamicLibrarySearchGenerator::Load( 248fe3594f7SNicolas Vasilache libPath.data(), dataLayout.getGlobalPrefix()); 249fe3594f7SNicolas Vasilache if (!loaded) { 250fe3594f7SNicolas Vasilache errs() << "Could not load: " << libPath << "\n"; 251fe3594f7SNicolas Vasilache continue; 252fe3594f7SNicolas Vasilache } 253fe3594f7SNicolas Vasilache JD.addGenerator(std::move(*loaded)); 254fe3594f7SNicolas Vasilache cantFail(objectLayer->add(JD, std::move(mb.get()))); 255fe3594f7SNicolas Vasilache } 256fe3594f7SNicolas Vasilache 257fe3594f7SNicolas Vasilache return objectLayer; 258fe3594f7SNicolas Vasilache }; 259fe3594f7SNicolas Vasilache 260fe3594f7SNicolas Vasilache // Callback to inspect the cache and recompile on demand. This follows Lang's 261fe3594f7SNicolas Vasilache // LLJITWithObjectCache example. 262fe3594f7SNicolas Vasilache auto compileFunctionCreator = [&](JITTargetMachineBuilder JTMB) 263fe3594f7SNicolas Vasilache -> Expected<IRCompileLayer::CompileFunction> { 264fe3594f7SNicolas Vasilache auto TM = JTMB.createTargetMachine(); 265fe3594f7SNicolas Vasilache if (!TM) 266fe3594f7SNicolas Vasilache return TM.takeError(); 267fe3594f7SNicolas Vasilache return IRCompileLayer::CompileFunction( 268fe3594f7SNicolas Vasilache TMOwningSimpleCompiler(std::move(*TM), engine->cache.get())); 269fe3594f7SNicolas Vasilache }; 270fe3594f7SNicolas Vasilache 271fe3594f7SNicolas Vasilache // Create the LLJIT by calling the LLJITBuilder with 2 callbacks. 272fe3594f7SNicolas Vasilache auto jit = 273fe3594f7SNicolas Vasilache cantFail(llvm::orc::LLJITBuilder() 274fe3594f7SNicolas Vasilache .setCompileFunctionCreator(compileFunctionCreator) 275fe3594f7SNicolas Vasilache .setObjectLinkingLayerCreator(objectLinkingLayerCreator) 276fe3594f7SNicolas Vasilache .create()); 277fe3594f7SNicolas Vasilache 278fe3594f7SNicolas Vasilache // Add a ThreadSafemodule to the engine and return. 279fe3594f7SNicolas Vasilache ThreadSafeModule tsm(std::move(deserModule), std::move(ctx)); 280*cf26e5faSNicolas Vasilache if (transformer) 281*cf26e5faSNicolas Vasilache cantFail(tsm.withModuleDo( 282*cf26e5faSNicolas Vasilache [&](llvm::Module &module) { return transformer(&module); })); 283fe3594f7SNicolas Vasilache cantFail(jit->addIRModule(std::move(tsm))); 284fe3594f7SNicolas Vasilache engine->jit = std::move(jit); 2855a440378SAlex Zinenko 286e7111fd6SJacques Pienaar return std::move(engine); 2875a440378SAlex Zinenko } 2885a440378SAlex Zinenko 2895a440378SAlex Zinenko Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const { 2905a440378SAlex Zinenko auto expectedSymbol = jit->lookup(makePackedFunctionName(name)); 2915a440378SAlex Zinenko if (!expectedSymbol) 2925a440378SAlex Zinenko return expectedSymbol.takeError(); 2935a440378SAlex Zinenko auto rawFPtr = expectedSymbol->getAddress(); 2945a440378SAlex Zinenko auto fptr = reinterpret_cast<void (*)(void **)>(rawFPtr); 2955a440378SAlex Zinenko if (!fptr) 2965a440378SAlex Zinenko return make_string_error("looked up function is null"); 2975a440378SAlex Zinenko return fptr; 2985a440378SAlex Zinenko } 299629f5b7fSNicolas Vasilache 300fe3594f7SNicolas Vasilache Error ExecutionEngine::invoke(StringRef name, MutableArrayRef<void *> args) { 301629f5b7fSNicolas Vasilache auto expectedFPtr = lookup(name); 302629f5b7fSNicolas Vasilache if (!expectedFPtr) 303629f5b7fSNicolas Vasilache return expectedFPtr.takeError(); 304629f5b7fSNicolas Vasilache auto fptr = *expectedFPtr; 305629f5b7fSNicolas Vasilache 306629f5b7fSNicolas Vasilache (*fptr)(args.data()); 307629f5b7fSNicolas Vasilache 308fe3594f7SNicolas Vasilache return Error::success(); 309629f5b7fSNicolas Vasilache } 310fe3594f7SNicolas Vasilache } // end namespace mlir 311