15a440378SAlex Zinenko //===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===// 25a440378SAlex Zinenko // 35a440378SAlex Zinenko // Copyright 2019 The MLIR Authors. 45a440378SAlex Zinenko // 55a440378SAlex Zinenko // Licensed under the Apache License, Version 2.0 (the "License"); 65a440378SAlex Zinenko // you may not use this file except in compliance with the License. 75a440378SAlex Zinenko // You may obtain a copy of the License at 85a440378SAlex Zinenko // 95a440378SAlex Zinenko // http://www.apache.org/licenses/LICENSE-2.0 105a440378SAlex Zinenko // 115a440378SAlex Zinenko // Unless required by applicable law or agreed to in writing, software 125a440378SAlex Zinenko // distributed under the License is distributed on an "AS IS" BASIS, 135a440378SAlex Zinenko // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 145a440378SAlex Zinenko // See the License for the specific language governing permissions and 155a440378SAlex Zinenko // limitations under the License. 165a440378SAlex Zinenko // ============================================================================= 175a440378SAlex Zinenko // 185a440378SAlex Zinenko // This file implements the execution engine for MLIR modules based on LLVM Orc 195a440378SAlex Zinenko // JIT engine. 205a440378SAlex Zinenko // 215a440378SAlex Zinenko //===----------------------------------------------------------------------===// 225a440378SAlex Zinenko #include "mlir/ExecutionEngine/ExecutionEngine.h" 235a440378SAlex Zinenko #include "mlir/IR/Function.h" 245a440378SAlex Zinenko #include "mlir/IR/Module.h" 2506e81010SJacques Pienaar #include "mlir/Support/FileUtilities.h" 265a440378SAlex Zinenko #include "mlir/Target/LLVMIR.h" 275a440378SAlex Zinenko 28fe3594f7SNicolas Vasilache #include "llvm/Bitcode/BitcodeReader.h" 29fe3594f7SNicolas Vasilache #include "llvm/Bitcode/BitcodeWriter.h" 30fe3594f7SNicolas Vasilache #include "llvm/ExecutionEngine/ObjectCache.h" 315a440378SAlex Zinenko #include "llvm/ExecutionEngine/Orc/CompileUtils.h" 325a440378SAlex Zinenko #include "llvm/ExecutionEngine/Orc/ExecutionUtils.h" 335a440378SAlex Zinenko #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" 348093f17aSAlex Zinenko #include "llvm/ExecutionEngine/Orc/IRTransformLayer.h" 355a440378SAlex Zinenko #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" 365a440378SAlex Zinenko #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" 375a440378SAlex Zinenko #include "llvm/ExecutionEngine/SectionMemoryManager.h" 385a440378SAlex Zinenko #include "llvm/IR/IRBuilder.h" 3953bb528bSMehdi Amini #include "llvm/Support/Debug.h" 405a440378SAlex Zinenko #include "llvm/Support/Error.h" 415a440378SAlex Zinenko #include "llvm/Support/TargetRegistry.h" 4206e81010SJacques Pienaar #include "llvm/Support/ToolOutputFile.h" 435a440378SAlex Zinenko 4453bb528bSMehdi Amini #define DEBUG_TYPE "execution-engine" 4553bb528bSMehdi Amini 465a440378SAlex Zinenko using namespace mlir; 47fe3594f7SNicolas Vasilache using llvm::dbgs; 485a440378SAlex Zinenko using llvm::Error; 49fe3594f7SNicolas Vasilache using llvm::errs; 505a440378SAlex Zinenko using llvm::Expected; 51fe3594f7SNicolas Vasilache using llvm::LLVMContext; 52fe3594f7SNicolas Vasilache using llvm::MemoryBuffer; 53fe3594f7SNicolas Vasilache using llvm::MemoryBufferRef; 54fe3594f7SNicolas Vasilache using llvm::Module; 55fe3594f7SNicolas Vasilache using llvm::SectionMemoryManager; 56fe3594f7SNicolas Vasilache using llvm::StringError; 57fe3594f7SNicolas Vasilache using llvm::Triple; 58fe3594f7SNicolas Vasilache using llvm::orc::DynamicLibrarySearchGenerator; 59fe3594f7SNicolas Vasilache using llvm::orc::ExecutionSession; 60fe3594f7SNicolas Vasilache using llvm::orc::IRCompileLayer; 61fe3594f7SNicolas Vasilache using llvm::orc::JITTargetMachineBuilder; 62fe3594f7SNicolas Vasilache using llvm::orc::RTDyldObjectLinkingLayer; 63fe3594f7SNicolas Vasilache using llvm::orc::ThreadSafeModule; 64fe3594f7SNicolas Vasilache using llvm::orc::TMOwningSimpleCompiler; 656aa5cc8bSNicolas Vasilache 665a440378SAlex Zinenko // Wrap a string into an llvm::StringError. 675a440378SAlex Zinenko static inline Error make_string_error(const llvm::Twine &message) { 68fe3594f7SNicolas Vasilache return llvm::make_error<StringError>(message.str(), 695a440378SAlex Zinenko llvm::inconvertibleErrorCode()); 705a440378SAlex Zinenko } 715a440378SAlex Zinenko 72fe3594f7SNicolas Vasilache namespace mlir { 73fe3594f7SNicolas Vasilache 74fe3594f7SNicolas Vasilache void SimpleObjectCache::notifyObjectCompiled(const Module *M, 75fe3594f7SNicolas Vasilache MemoryBufferRef ObjBuffer) { 7606e81010SJacques Pienaar cachedObjects[M->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy( 77fe3594f7SNicolas Vasilache ObjBuffer.getBuffer(), ObjBuffer.getBufferIdentifier()); 78fe3594f7SNicolas Vasilache } 79fe3594f7SNicolas Vasilache 80fe3594f7SNicolas Vasilache std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *M) { 8106e81010SJacques Pienaar auto I = cachedObjects.find(M->getModuleIdentifier()); 8206e81010SJacques Pienaar if (I == cachedObjects.end()) { 8353bb528bSMehdi Amini LLVM_DEBUG(dbgs() << "No object for " << M->getModuleIdentifier() 8453bb528bSMehdi Amini << " in cache. Compiling.\n"); 85fe3594f7SNicolas Vasilache return nullptr; 86fe3594f7SNicolas Vasilache } 8753bb528bSMehdi Amini LLVM_DEBUG(dbgs() << "Object for " << M->getModuleIdentifier() 8853bb528bSMehdi Amini << " loaded from cache.\n"); 89fe3594f7SNicolas Vasilache return MemoryBuffer::getMemBuffer(I->second->getMemBufferRef()); 90fe3594f7SNicolas Vasilache } 91fe3594f7SNicolas Vasilache 9206e81010SJacques Pienaar void SimpleObjectCache::dumpToObjectFile(llvm::StringRef outputFilename) { 9306e81010SJacques Pienaar // Set up the output file. 9406e81010SJacques Pienaar std::string errorMessage; 9506e81010SJacques Pienaar auto file = openOutputFile(outputFilename, &errorMessage); 9606e81010SJacques Pienaar if (!file) { 9706e81010SJacques Pienaar llvm::errs() << errorMessage << "\n"; 9806e81010SJacques Pienaar return; 9906e81010SJacques Pienaar } 10006e81010SJacques Pienaar 10106e81010SJacques Pienaar // Dump the object generated for a single module to the output file. 10206e81010SJacques Pienaar assert(cachedObjects.size() == 1 && "Expected only one object entry."); 10306e81010SJacques Pienaar auto &cachedObject = cachedObjects.begin()->second; 10406e81010SJacques Pienaar file->os() << cachedObject->getBuffer(); 10506e81010SJacques Pienaar file->keep(); 10606e81010SJacques Pienaar } 10706e81010SJacques Pienaar 10806e81010SJacques Pienaar void ExecutionEngine::dumpToObjectFile(llvm::StringRef filename) { 10906e81010SJacques Pienaar cache->dumpToObjectFile(filename); 11006e81010SJacques Pienaar } 11106e81010SJacques Pienaar 1125a440378SAlex Zinenko // Setup LLVM target triple from the current machine. 113fe3594f7SNicolas Vasilache bool ExecutionEngine::setupTargetTriple(Module *llvmModule) { 1145a440378SAlex Zinenko // Setup the machine properties from the current architecture. 1155a440378SAlex Zinenko auto targetTriple = llvm::sys::getDefaultTargetTriple(); 1165a440378SAlex Zinenko std::string errorMessage; 1175a440378SAlex Zinenko auto target = llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage); 1185a440378SAlex Zinenko if (!target) { 119fe3594f7SNicolas Vasilache errs() << "NO target: " << errorMessage << "\n"; 1205a440378SAlex Zinenko return true; 1215a440378SAlex Zinenko } 122*d732aaf2SMLIR Team std::unique_ptr<llvm::TargetMachine> machine( 123*d732aaf2SMLIR Team target->createTargetMachine(targetTriple, "generic", "", {}, {})); 1245a440378SAlex Zinenko llvmModule->setDataLayout(machine->createDataLayout()); 1255a440378SAlex Zinenko llvmModule->setTargetTriple(targetTriple); 1265a440378SAlex Zinenko return false; 1275a440378SAlex Zinenko } 1285a440378SAlex Zinenko 1295a440378SAlex Zinenko static std::string makePackedFunctionName(StringRef name) { 1305a440378SAlex Zinenko return "_mlir_" + name.str(); 1315a440378SAlex Zinenko } 1325a440378SAlex Zinenko 1335a440378SAlex Zinenko // For each function in the LLVM module, define an interface function that wraps 1345a440378SAlex Zinenko // all the arguments of the original function and all its results into an i8** 1355a440378SAlex Zinenko // pointer to provide a unified invocation interface. 136fe3594f7SNicolas Vasilache void packFunctionArguments(Module *module) { 1375a440378SAlex Zinenko auto &ctx = module->getContext(); 1385a440378SAlex Zinenko llvm::IRBuilder<> builder(ctx); 1395a440378SAlex Zinenko llvm::DenseSet<llvm::Function *> interfaceFunctions; 1405a440378SAlex Zinenko for (auto &func : module->getFunctionList()) { 1415a440378SAlex Zinenko if (func.isDeclaration()) { 1425a440378SAlex Zinenko continue; 1435a440378SAlex Zinenko } 1445a440378SAlex Zinenko if (interfaceFunctions.count(&func)) { 1455a440378SAlex Zinenko continue; 1465a440378SAlex Zinenko } 1475a440378SAlex Zinenko 1485a440378SAlex Zinenko // Given a function `foo(<...>)`, define the interface function 1495a440378SAlex Zinenko // `mlir_foo(i8**)`. 1505a440378SAlex Zinenko auto newType = llvm::FunctionType::get( 1515a440378SAlex Zinenko builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(), 1525a440378SAlex Zinenko /*isVarArg=*/false); 1535a440378SAlex Zinenko auto newName = makePackedFunctionName(func.getName()); 154c46b0feaSRiver Riddle auto funcCst = module->getOrInsertFunction(newName, newType); 155c46b0feaSRiver Riddle llvm::Function *interfaceFunc = 156c46b0feaSRiver Riddle llvm::cast<llvm::Function>(funcCst.getCallee()); 1575a440378SAlex Zinenko interfaceFunctions.insert(interfaceFunc); 1585a440378SAlex Zinenko 1595a440378SAlex Zinenko // Extract the arguments from the type-erased argument list and cast them to 1605a440378SAlex Zinenko // the proper types. 1615a440378SAlex Zinenko auto bb = llvm::BasicBlock::Create(ctx); 1625a440378SAlex Zinenko bb->insertInto(interfaceFunc); 1635a440378SAlex Zinenko builder.SetInsertPoint(bb); 1645a440378SAlex Zinenko llvm::Value *argList = interfaceFunc->arg_begin(); 1655a440378SAlex Zinenko llvm::SmallVector<llvm::Value *, 8> args; 1665a440378SAlex Zinenko args.reserve(llvm::size(func.args())); 1675a440378SAlex Zinenko for (auto &indexedArg : llvm::enumerate(func.args())) { 1685a440378SAlex Zinenko llvm::Value *argIndex = llvm::Constant::getIntegerValue( 1695a440378SAlex Zinenko builder.getInt64Ty(), llvm::APInt(64, indexedArg.index())); 1705a440378SAlex Zinenko llvm::Value *argPtrPtr = builder.CreateGEP(argList, argIndex); 1715a440378SAlex Zinenko llvm::Value *argPtr = builder.CreateLoad(argPtrPtr); 1725a440378SAlex Zinenko argPtr = builder.CreateBitCast( 1735a440378SAlex Zinenko argPtr, indexedArg.value().getType()->getPointerTo()); 1745a440378SAlex Zinenko llvm::Value *arg = builder.CreateLoad(argPtr); 1755a440378SAlex Zinenko args.push_back(arg); 1765a440378SAlex Zinenko } 1775a440378SAlex Zinenko 1785a440378SAlex Zinenko // Call the implementation function with the extracted arguments. 1795a440378SAlex Zinenko llvm::Value *result = builder.CreateCall(&func, args); 1805a440378SAlex Zinenko 1815a440378SAlex Zinenko // Assuming the result is one value, potentially of type `void`. 1825a440378SAlex Zinenko if (!result->getType()->isVoidTy()) { 1835a440378SAlex Zinenko llvm::Value *retIndex = llvm::Constant::getIntegerValue( 1845a440378SAlex Zinenko builder.getInt64Ty(), llvm::APInt(64, llvm::size(func.args()))); 1855a440378SAlex Zinenko llvm::Value *retPtrPtr = builder.CreateGEP(argList, retIndex); 1865a440378SAlex Zinenko llvm::Value *retPtr = builder.CreateLoad(retPtrPtr); 1875a440378SAlex Zinenko retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo()); 1885a440378SAlex Zinenko builder.CreateStore(result, retPtr); 1895a440378SAlex Zinenko } 1905a440378SAlex Zinenko 1915a440378SAlex Zinenko // The interface function returns void. 1925a440378SAlex Zinenko builder.CreateRetVoid(); 1935a440378SAlex Zinenko } 1945a440378SAlex Zinenko } 1955a440378SAlex Zinenko 19606e81010SJacques Pienaar ExecutionEngine::ExecutionEngine(bool enableObjectCache) 19706e81010SJacques Pienaar : cache(enableObjectCache ? nullptr : new SimpleObjectCache()) {} 19806e81010SJacques Pienaar 19906e81010SJacques Pienaar Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create( 20006e81010SJacques Pienaar ModuleOp m, std::function<Error(llvm::Module *)> transformer, 201713ab0ddSUday Bondhugula Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel, 20206e81010SJacques Pienaar ArrayRef<StringRef> sharedLibPaths, bool enableObjectCache) { 20306e81010SJacques Pienaar auto engine = std::make_unique<ExecutionEngine>(enableObjectCache); 2045a440378SAlex Zinenko 205fe3594f7SNicolas Vasilache std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext); 206206e55ccSRiver Riddle auto llvmModule = translateModuleToLLVMIR(m); 2075a440378SAlex Zinenko if (!llvmModule) 2085a440378SAlex Zinenko return make_string_error("could not convert to LLVM IR"); 2095a440378SAlex Zinenko // FIXME: the triple should be passed to the translation or dialect conversion 2105a440378SAlex Zinenko // instead of this. Currently, the LLVM module created above has no triple 2115a440378SAlex Zinenko // associated with it. 2125a440378SAlex Zinenko setupTargetTriple(llvmModule.get()); 2135a440378SAlex Zinenko packFunctionArguments(llvmModule.get()); 2145a440378SAlex Zinenko 215fe3594f7SNicolas Vasilache // Clone module in a new LLVMContext since translateModuleToLLVMIR buries 216fe3594f7SNicolas Vasilache // ownership too deeply. 217fe3594f7SNicolas Vasilache // TODO(zinenko): Reevaluate model of ownership of LLVMContext in LLVMDialect. 218fe3594f7SNicolas Vasilache SmallVector<char, 1> buffer; 219fe3594f7SNicolas Vasilache { 220fe3594f7SNicolas Vasilache llvm::raw_svector_ostream os(buffer); 221fe3594f7SNicolas Vasilache WriteBitcodeToFile(*llvmModule, os); 222fe3594f7SNicolas Vasilache } 223fe3594f7SNicolas Vasilache llvm::MemoryBufferRef bufferRef(llvm::StringRef(buffer.data(), buffer.size()), 224fe3594f7SNicolas Vasilache "cloned module buffer"); 225fe3594f7SNicolas Vasilache auto expectedModule = parseBitcodeFile(bufferRef, *ctx); 226fe3594f7SNicolas Vasilache if (!expectedModule) 227fe3594f7SNicolas Vasilache return expectedModule.takeError(); 228fe3594f7SNicolas Vasilache std::unique_ptr<Module> deserModule = std::move(*expectedModule); 229fe3594f7SNicolas Vasilache 230fe3594f7SNicolas Vasilache // Callback to create the object layer with symbol resolution to current 231fe3594f7SNicolas Vasilache // process and dynamically linked libraries. 232fe3594f7SNicolas Vasilache auto objectLinkingLayerCreator = [&](ExecutionSession &session, 233fe3594f7SNicolas Vasilache const Triple &TT) { 234fe3594f7SNicolas Vasilache auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>( 235fe3594f7SNicolas Vasilache session, []() { return std::make_unique<SectionMemoryManager>(); }); 236fe3594f7SNicolas Vasilache auto dataLayout = deserModule->getDataLayout(); 237fe3594f7SNicolas Vasilache 238fe3594f7SNicolas Vasilache // Resolve symbols that are statically linked in the current process. 239fe3594f7SNicolas Vasilache session.getMainJITDylib().addGenerator( 240fe3594f7SNicolas Vasilache cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess( 241fe3594f7SNicolas Vasilache dataLayout.getGlobalPrefix()))); 242fe3594f7SNicolas Vasilache 243fe3594f7SNicolas Vasilache // Resolve symbols from shared libraries. 244fe3594f7SNicolas Vasilache for (auto libPath : sharedLibPaths) { 245fe3594f7SNicolas Vasilache auto mb = llvm::MemoryBuffer::getFile(libPath); 246fe3594f7SNicolas Vasilache if (!mb) { 247fe3594f7SNicolas Vasilache errs() << "Fail to create MemoryBuffer for: " << libPath << "\n"; 248fe3594f7SNicolas Vasilache continue; 249fe3594f7SNicolas Vasilache } 250fe3594f7SNicolas Vasilache auto &JD = session.createJITDylib(libPath); 251fe3594f7SNicolas Vasilache auto loaded = DynamicLibrarySearchGenerator::Load( 252fe3594f7SNicolas Vasilache libPath.data(), dataLayout.getGlobalPrefix()); 253fe3594f7SNicolas Vasilache if (!loaded) { 254fe3594f7SNicolas Vasilache errs() << "Could not load: " << libPath << "\n"; 255fe3594f7SNicolas Vasilache continue; 256fe3594f7SNicolas Vasilache } 257fe3594f7SNicolas Vasilache JD.addGenerator(std::move(*loaded)); 258fe3594f7SNicolas Vasilache cantFail(objectLayer->add(JD, std::move(mb.get()))); 259fe3594f7SNicolas Vasilache } 260fe3594f7SNicolas Vasilache 261fe3594f7SNicolas Vasilache return objectLayer; 262fe3594f7SNicolas Vasilache }; 263fe3594f7SNicolas Vasilache 264fe3594f7SNicolas Vasilache // Callback to inspect the cache and recompile on demand. This follows Lang's 265fe3594f7SNicolas Vasilache // LLJITWithObjectCache example. 266fe3594f7SNicolas Vasilache auto compileFunctionCreator = [&](JITTargetMachineBuilder JTMB) 267fe3594f7SNicolas Vasilache -> Expected<IRCompileLayer::CompileFunction> { 268713ab0ddSUday Bondhugula if (jitCodeGenOptLevel) 269713ab0ddSUday Bondhugula JTMB.setCodeGenOptLevel(jitCodeGenOptLevel.getValue()); 270fe3594f7SNicolas Vasilache auto TM = JTMB.createTargetMachine(); 271fe3594f7SNicolas Vasilache if (!TM) 272fe3594f7SNicolas Vasilache return TM.takeError(); 273fe3594f7SNicolas Vasilache return IRCompileLayer::CompileFunction( 274fe3594f7SNicolas Vasilache TMOwningSimpleCompiler(std::move(*TM), engine->cache.get())); 275fe3594f7SNicolas Vasilache }; 276fe3594f7SNicolas Vasilache 277fe3594f7SNicolas Vasilache // Create the LLJIT by calling the LLJITBuilder with 2 callbacks. 278fe3594f7SNicolas Vasilache auto jit = 279fe3594f7SNicolas Vasilache cantFail(llvm::orc::LLJITBuilder() 280fe3594f7SNicolas Vasilache .setCompileFunctionCreator(compileFunctionCreator) 281fe3594f7SNicolas Vasilache .setObjectLinkingLayerCreator(objectLinkingLayerCreator) 282fe3594f7SNicolas Vasilache .create()); 283fe3594f7SNicolas Vasilache 284fe3594f7SNicolas Vasilache // Add a ThreadSafemodule to the engine and return. 285fe3594f7SNicolas Vasilache ThreadSafeModule tsm(std::move(deserModule), std::move(ctx)); 286cf26e5faSNicolas Vasilache if (transformer) 287cf26e5faSNicolas Vasilache cantFail(tsm.withModuleDo( 288cf26e5faSNicolas Vasilache [&](llvm::Module &module) { return transformer(&module); })); 289fe3594f7SNicolas Vasilache cantFail(jit->addIRModule(std::move(tsm))); 290fe3594f7SNicolas Vasilache engine->jit = std::move(jit); 2915a440378SAlex Zinenko 292e7111fd6SJacques Pienaar return std::move(engine); 2935a440378SAlex Zinenko } 2945a440378SAlex Zinenko 2955a440378SAlex Zinenko Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const { 2965a440378SAlex Zinenko auto expectedSymbol = jit->lookup(makePackedFunctionName(name)); 2975a440378SAlex Zinenko if (!expectedSymbol) 2985a440378SAlex Zinenko return expectedSymbol.takeError(); 2995a440378SAlex Zinenko auto rawFPtr = expectedSymbol->getAddress(); 3005a440378SAlex Zinenko auto fptr = reinterpret_cast<void (*)(void **)>(rawFPtr); 3015a440378SAlex Zinenko if (!fptr) 3025a440378SAlex Zinenko return make_string_error("looked up function is null"); 3035a440378SAlex Zinenko return fptr; 3045a440378SAlex Zinenko } 305629f5b7fSNicolas Vasilache 306fe3594f7SNicolas Vasilache Error ExecutionEngine::invoke(StringRef name, MutableArrayRef<void *> args) { 307629f5b7fSNicolas Vasilache auto expectedFPtr = lookup(name); 308629f5b7fSNicolas Vasilache if (!expectedFPtr) 309629f5b7fSNicolas Vasilache return expectedFPtr.takeError(); 310629f5b7fSNicolas Vasilache auto fptr = *expectedFPtr; 311629f5b7fSNicolas Vasilache 312629f5b7fSNicolas Vasilache (*fptr)(args.data()); 313629f5b7fSNicolas Vasilache 314fe3594f7SNicolas Vasilache return Error::success(); 315629f5b7fSNicolas Vasilache } 316fe3594f7SNicolas Vasilache } // end namespace mlir 317