15a440378SAlex Zinenko //===- ExecutionEngine.cpp - MLIR Execution engine and utils --------------===//
25a440378SAlex Zinenko //
35a440378SAlex Zinenko // Copyright 2019 The MLIR Authors.
45a440378SAlex Zinenko //
55a440378SAlex Zinenko // Licensed under the Apache License, Version 2.0 (the "License");
65a440378SAlex Zinenko // you may not use this file except in compliance with the License.
75a440378SAlex Zinenko // You may obtain a copy of the License at
85a440378SAlex Zinenko //
95a440378SAlex Zinenko //   http://www.apache.org/licenses/LICENSE-2.0
105a440378SAlex Zinenko //
115a440378SAlex Zinenko // Unless required by applicable law or agreed to in writing, software
125a440378SAlex Zinenko // distributed under the License is distributed on an "AS IS" BASIS,
135a440378SAlex Zinenko // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
145a440378SAlex Zinenko // See the License for the specific language governing permissions and
155a440378SAlex Zinenko // limitations under the License.
165a440378SAlex Zinenko // =============================================================================
175a440378SAlex Zinenko //
185a440378SAlex Zinenko // This file implements the execution engine for MLIR modules based on LLVM Orc
195a440378SAlex Zinenko // JIT engine.
205a440378SAlex Zinenko //
215a440378SAlex Zinenko //===----------------------------------------------------------------------===//
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/Support/FileUtilities.h"
#include "mlir/Target/LLVMIR.h"

#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/Bitcode/BitcodeWriter.h"
#include "llvm/ExecutionEngine/ObjectCache.h"
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
#include "llvm/ExecutionEngine/Orc/ExecutionUtils.h"
#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
#include "llvm/ExecutionEngine/Orc/IRTransformLayer.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h"
#include "llvm/ExecutionEngine/SectionMemoryManager.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetRegistry.h"
#include "llvm/Support/ToolOutputFile.h"
#include "llvm/Target/TargetMachine.h"
425a440378SAlex Zinenko 
435a440378SAlex Zinenko using namespace mlir;
44fe3594f7SNicolas Vasilache using llvm::dbgs;
455a440378SAlex Zinenko using llvm::Error;
46fe3594f7SNicolas Vasilache using llvm::errs;
475a440378SAlex Zinenko using llvm::Expected;
48fe3594f7SNicolas Vasilache using llvm::LLVMContext;
49fe3594f7SNicolas Vasilache using llvm::MemoryBuffer;
50fe3594f7SNicolas Vasilache using llvm::MemoryBufferRef;
51fe3594f7SNicolas Vasilache using llvm::Module;
52fe3594f7SNicolas Vasilache using llvm::SectionMemoryManager;
53fe3594f7SNicolas Vasilache using llvm::StringError;
54fe3594f7SNicolas Vasilache using llvm::Triple;
55fe3594f7SNicolas Vasilache using llvm::orc::DynamicLibrarySearchGenerator;
56fe3594f7SNicolas Vasilache using llvm::orc::ExecutionSession;
57fe3594f7SNicolas Vasilache using llvm::orc::IRCompileLayer;
58fe3594f7SNicolas Vasilache using llvm::orc::JITTargetMachineBuilder;
59fe3594f7SNicolas Vasilache using llvm::orc::RTDyldObjectLinkingLayer;
60fe3594f7SNicolas Vasilache using llvm::orc::ThreadSafeModule;
61fe3594f7SNicolas Vasilache using llvm::orc::TMOwningSimpleCompiler;
626aa5cc8bSNicolas Vasilache 
635a440378SAlex Zinenko // Wrap a string into an llvm::StringError.
645a440378SAlex Zinenko static inline Error make_string_error(const llvm::Twine &message) {
65fe3594f7SNicolas Vasilache   return llvm::make_error<StringError>(message.str(),
665a440378SAlex Zinenko                                        llvm::inconvertibleErrorCode());
675a440378SAlex Zinenko }
685a440378SAlex Zinenko 
69fe3594f7SNicolas Vasilache namespace mlir {
70fe3594f7SNicolas Vasilache 
71fe3594f7SNicolas Vasilache void SimpleObjectCache::notifyObjectCompiled(const Module *M,
72fe3594f7SNicolas Vasilache                                              MemoryBufferRef ObjBuffer) {
7306e81010SJacques Pienaar   cachedObjects[M->getModuleIdentifier()] = MemoryBuffer::getMemBufferCopy(
74fe3594f7SNicolas Vasilache       ObjBuffer.getBuffer(), ObjBuffer.getBufferIdentifier());
75fe3594f7SNicolas Vasilache }
76fe3594f7SNicolas Vasilache 
77fe3594f7SNicolas Vasilache std::unique_ptr<MemoryBuffer> SimpleObjectCache::getObject(const Module *M) {
7806e81010SJacques Pienaar   auto I = cachedObjects.find(M->getModuleIdentifier());
7906e81010SJacques Pienaar   if (I == cachedObjects.end()) {
80fe3594f7SNicolas Vasilache     dbgs() << "No object for " << M->getModuleIdentifier()
81fe3594f7SNicolas Vasilache            << " in cache. Compiling.\n";
82fe3594f7SNicolas Vasilache     return nullptr;
83fe3594f7SNicolas Vasilache   }
84fe3594f7SNicolas Vasilache   dbgs() << "Object for " << M->getModuleIdentifier()
85fe3594f7SNicolas Vasilache          << " loaded from cache.\n";
86fe3594f7SNicolas Vasilache   return MemoryBuffer::getMemBuffer(I->second->getMemBufferRef());
87fe3594f7SNicolas Vasilache }
88fe3594f7SNicolas Vasilache 
8906e81010SJacques Pienaar void SimpleObjectCache::dumpToObjectFile(llvm::StringRef outputFilename) {
9006e81010SJacques Pienaar   // Set up the output file.
9106e81010SJacques Pienaar   std::string errorMessage;
9206e81010SJacques Pienaar   auto file = openOutputFile(outputFilename, &errorMessage);
9306e81010SJacques Pienaar   if (!file) {
9406e81010SJacques Pienaar     llvm::errs() << errorMessage << "\n";
9506e81010SJacques Pienaar     return;
9606e81010SJacques Pienaar   }
9706e81010SJacques Pienaar 
9806e81010SJacques Pienaar   // Dump the object generated for a single module to the output file.
9906e81010SJacques Pienaar   assert(cachedObjects.size() == 1 && "Expected only one object entry.");
10006e81010SJacques Pienaar   auto &cachedObject = cachedObjects.begin()->second;
10106e81010SJacques Pienaar   file->os() << cachedObject->getBuffer();
10206e81010SJacques Pienaar   file->keep();
10306e81010SJacques Pienaar }
10406e81010SJacques Pienaar 
10506e81010SJacques Pienaar void ExecutionEngine::dumpToObjectFile(llvm::StringRef filename) {
10606e81010SJacques Pienaar   cache->dumpToObjectFile(filename);
10706e81010SJacques Pienaar }
10806e81010SJacques Pienaar 
1095a440378SAlex Zinenko // Setup LLVM target triple from the current machine.
110fe3594f7SNicolas Vasilache bool ExecutionEngine::setupTargetTriple(Module *llvmModule) {
1115a440378SAlex Zinenko   // Setup the machine properties from the current architecture.
1125a440378SAlex Zinenko   auto targetTriple = llvm::sys::getDefaultTargetTriple();
1135a440378SAlex Zinenko   std::string errorMessage;
1145a440378SAlex Zinenko   auto target = llvm::TargetRegistry::lookupTarget(targetTriple, errorMessage);
1155a440378SAlex Zinenko   if (!target) {
116fe3594f7SNicolas Vasilache     errs() << "NO target: " << errorMessage << "\n";
1175a440378SAlex Zinenko     return true;
1185a440378SAlex Zinenko   }
1195a440378SAlex Zinenko   auto machine =
1205a440378SAlex Zinenko       target->createTargetMachine(targetTriple, "generic", "", {}, {});
1215a440378SAlex Zinenko   llvmModule->setDataLayout(machine->createDataLayout());
1225a440378SAlex Zinenko   llvmModule->setTargetTriple(targetTriple);
1235a440378SAlex Zinenko   return false;
1245a440378SAlex Zinenko }
1255a440378SAlex Zinenko 
1265a440378SAlex Zinenko static std::string makePackedFunctionName(StringRef name) {
1275a440378SAlex Zinenko   return "_mlir_" + name.str();
1285a440378SAlex Zinenko }
1295a440378SAlex Zinenko 
1305a440378SAlex Zinenko // For each function in the LLVM module, define an interface function that wraps
1315a440378SAlex Zinenko // all the arguments of the original function and all its results into an i8**
1325a440378SAlex Zinenko // pointer to provide a unified invocation interface.
133fe3594f7SNicolas Vasilache void packFunctionArguments(Module *module) {
1345a440378SAlex Zinenko   auto &ctx = module->getContext();
1355a440378SAlex Zinenko   llvm::IRBuilder<> builder(ctx);
1365a440378SAlex Zinenko   llvm::DenseSet<llvm::Function *> interfaceFunctions;
1375a440378SAlex Zinenko   for (auto &func : module->getFunctionList()) {
1385a440378SAlex Zinenko     if (func.isDeclaration()) {
1395a440378SAlex Zinenko       continue;
1405a440378SAlex Zinenko     }
1415a440378SAlex Zinenko     if (interfaceFunctions.count(&func)) {
1425a440378SAlex Zinenko       continue;
1435a440378SAlex Zinenko     }
1445a440378SAlex Zinenko 
1455a440378SAlex Zinenko     // Given a function `foo(<...>)`, define the interface function
1465a440378SAlex Zinenko     // `mlir_foo(i8**)`.
1475a440378SAlex Zinenko     auto newType = llvm::FunctionType::get(
1485a440378SAlex Zinenko         builder.getVoidTy(), builder.getInt8PtrTy()->getPointerTo(),
1495a440378SAlex Zinenko         /*isVarArg=*/false);
1505a440378SAlex Zinenko     auto newName = makePackedFunctionName(func.getName());
151c46b0feaSRiver Riddle     auto funcCst = module->getOrInsertFunction(newName, newType);
152c46b0feaSRiver Riddle     llvm::Function *interfaceFunc =
153c46b0feaSRiver Riddle         llvm::cast<llvm::Function>(funcCst.getCallee());
1545a440378SAlex Zinenko     interfaceFunctions.insert(interfaceFunc);
1555a440378SAlex Zinenko 
1565a440378SAlex Zinenko     // Extract the arguments from the type-erased argument list and cast them to
1575a440378SAlex Zinenko     // the proper types.
1585a440378SAlex Zinenko     auto bb = llvm::BasicBlock::Create(ctx);
1595a440378SAlex Zinenko     bb->insertInto(interfaceFunc);
1605a440378SAlex Zinenko     builder.SetInsertPoint(bb);
1615a440378SAlex Zinenko     llvm::Value *argList = interfaceFunc->arg_begin();
1625a440378SAlex Zinenko     llvm::SmallVector<llvm::Value *, 8> args;
1635a440378SAlex Zinenko     args.reserve(llvm::size(func.args()));
1645a440378SAlex Zinenko     for (auto &indexedArg : llvm::enumerate(func.args())) {
1655a440378SAlex Zinenko       llvm::Value *argIndex = llvm::Constant::getIntegerValue(
1665a440378SAlex Zinenko           builder.getInt64Ty(), llvm::APInt(64, indexedArg.index()));
1675a440378SAlex Zinenko       llvm::Value *argPtrPtr = builder.CreateGEP(argList, argIndex);
1685a440378SAlex Zinenko       llvm::Value *argPtr = builder.CreateLoad(argPtrPtr);
1695a440378SAlex Zinenko       argPtr = builder.CreateBitCast(
1705a440378SAlex Zinenko           argPtr, indexedArg.value().getType()->getPointerTo());
1715a440378SAlex Zinenko       llvm::Value *arg = builder.CreateLoad(argPtr);
1725a440378SAlex Zinenko       args.push_back(arg);
1735a440378SAlex Zinenko     }
1745a440378SAlex Zinenko 
1755a440378SAlex Zinenko     // Call the implementation function with the extracted arguments.
1765a440378SAlex Zinenko     llvm::Value *result = builder.CreateCall(&func, args);
1775a440378SAlex Zinenko 
1785a440378SAlex Zinenko     // Assuming the result is one value, potentially of type `void`.
1795a440378SAlex Zinenko     if (!result->getType()->isVoidTy()) {
1805a440378SAlex Zinenko       llvm::Value *retIndex = llvm::Constant::getIntegerValue(
1815a440378SAlex Zinenko           builder.getInt64Ty(), llvm::APInt(64, llvm::size(func.args())));
1825a440378SAlex Zinenko       llvm::Value *retPtrPtr = builder.CreateGEP(argList, retIndex);
1835a440378SAlex Zinenko       llvm::Value *retPtr = builder.CreateLoad(retPtrPtr);
1845a440378SAlex Zinenko       retPtr = builder.CreateBitCast(retPtr, result->getType()->getPointerTo());
1855a440378SAlex Zinenko       builder.CreateStore(result, retPtr);
1865a440378SAlex Zinenko     }
1875a440378SAlex Zinenko 
1885a440378SAlex Zinenko     // The interface function returns void.
1895a440378SAlex Zinenko     builder.CreateRetVoid();
1905a440378SAlex Zinenko   }
1915a440378SAlex Zinenko }
1925a440378SAlex Zinenko 
19306e81010SJacques Pienaar ExecutionEngine::ExecutionEngine(bool enableObjectCache)
19406e81010SJacques Pienaar     : cache(enableObjectCache ? nullptr : new SimpleObjectCache()) {}
19506e81010SJacques Pienaar 
19606e81010SJacques Pienaar Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
19706e81010SJacques Pienaar     ModuleOp m, std::function<Error(llvm::Module *)> transformer,
19806e81010SJacques Pienaar     ArrayRef<StringRef> sharedLibPaths, bool enableObjectCache) {
19906e81010SJacques Pienaar   auto engine = std::make_unique<ExecutionEngine>(enableObjectCache);
2005a440378SAlex Zinenko 
201fe3594f7SNicolas Vasilache   std::unique_ptr<llvm::LLVMContext> ctx(new llvm::LLVMContext);
202206e55ccSRiver Riddle   auto llvmModule = translateModuleToLLVMIR(m);
2035a440378SAlex Zinenko   if (!llvmModule)
2045a440378SAlex Zinenko     return make_string_error("could not convert to LLVM IR");
2055a440378SAlex Zinenko   // FIXME: the triple should be passed to the translation or dialect conversion
2065a440378SAlex Zinenko   // instead of this.  Currently, the LLVM module created above has no triple
2075a440378SAlex Zinenko   // associated with it.
2085a440378SAlex Zinenko   setupTargetTriple(llvmModule.get());
2095a440378SAlex Zinenko   packFunctionArguments(llvmModule.get());
2105a440378SAlex Zinenko 
211fe3594f7SNicolas Vasilache   // Clone module in a new LLVMContext since translateModuleToLLVMIR buries
212fe3594f7SNicolas Vasilache   // ownership too deeply.
213fe3594f7SNicolas Vasilache   // TODO(zinenko): Reevaluate model of ownership of LLVMContext in LLVMDialect.
214fe3594f7SNicolas Vasilache   SmallVector<char, 1> buffer;
215fe3594f7SNicolas Vasilache   {
216fe3594f7SNicolas Vasilache     llvm::raw_svector_ostream os(buffer);
217fe3594f7SNicolas Vasilache     WriteBitcodeToFile(*llvmModule, os);
218fe3594f7SNicolas Vasilache   }
219fe3594f7SNicolas Vasilache   llvm::MemoryBufferRef bufferRef(llvm::StringRef(buffer.data(), buffer.size()),
220fe3594f7SNicolas Vasilache                                   "cloned module buffer");
221fe3594f7SNicolas Vasilache   auto expectedModule = parseBitcodeFile(bufferRef, *ctx);
222fe3594f7SNicolas Vasilache   if (!expectedModule)
223fe3594f7SNicolas Vasilache     return expectedModule.takeError();
224fe3594f7SNicolas Vasilache   std::unique_ptr<Module> deserModule = std::move(*expectedModule);
225fe3594f7SNicolas Vasilache 
226fe3594f7SNicolas Vasilache   // Callback to create the object layer with symbol resolution to current
227fe3594f7SNicolas Vasilache   // process and dynamically linked libraries.
228fe3594f7SNicolas Vasilache   auto objectLinkingLayerCreator = [&](ExecutionSession &session,
229fe3594f7SNicolas Vasilache                                        const Triple &TT) {
230fe3594f7SNicolas Vasilache     auto objectLayer = std::make_unique<RTDyldObjectLinkingLayer>(
231fe3594f7SNicolas Vasilache         session, []() { return std::make_unique<SectionMemoryManager>(); });
232fe3594f7SNicolas Vasilache     auto dataLayout = deserModule->getDataLayout();
233fe3594f7SNicolas Vasilache 
234fe3594f7SNicolas Vasilache     // Resolve symbols that are statically linked in the current process.
235fe3594f7SNicolas Vasilache     session.getMainJITDylib().addGenerator(
236fe3594f7SNicolas Vasilache         cantFail(DynamicLibrarySearchGenerator::GetForCurrentProcess(
237fe3594f7SNicolas Vasilache             dataLayout.getGlobalPrefix())));
238fe3594f7SNicolas Vasilache 
239fe3594f7SNicolas Vasilache     // Resolve symbols from shared libraries.
240fe3594f7SNicolas Vasilache     for (auto libPath : sharedLibPaths) {
241fe3594f7SNicolas Vasilache       auto mb = llvm::MemoryBuffer::getFile(libPath);
242fe3594f7SNicolas Vasilache       if (!mb) {
243fe3594f7SNicolas Vasilache         errs() << "Fail to create MemoryBuffer for: " << libPath << "\n";
244fe3594f7SNicolas Vasilache         continue;
245fe3594f7SNicolas Vasilache       }
246fe3594f7SNicolas Vasilache       auto &JD = session.createJITDylib(libPath);
247fe3594f7SNicolas Vasilache       auto loaded = DynamicLibrarySearchGenerator::Load(
248fe3594f7SNicolas Vasilache           libPath.data(), dataLayout.getGlobalPrefix());
249fe3594f7SNicolas Vasilache       if (!loaded) {
250fe3594f7SNicolas Vasilache         errs() << "Could not load: " << libPath << "\n";
251fe3594f7SNicolas Vasilache         continue;
252fe3594f7SNicolas Vasilache       }
253fe3594f7SNicolas Vasilache       JD.addGenerator(std::move(*loaded));
254fe3594f7SNicolas Vasilache       cantFail(objectLayer->add(JD, std::move(mb.get())));
255fe3594f7SNicolas Vasilache     }
256fe3594f7SNicolas Vasilache 
257fe3594f7SNicolas Vasilache     return objectLayer;
258fe3594f7SNicolas Vasilache   };
259fe3594f7SNicolas Vasilache 
260fe3594f7SNicolas Vasilache   // Callback to inspect the cache and recompile on demand. This follows Lang's
261fe3594f7SNicolas Vasilache   // LLJITWithObjectCache example.
262fe3594f7SNicolas Vasilache   auto compileFunctionCreator = [&](JITTargetMachineBuilder JTMB)
263fe3594f7SNicolas Vasilache       -> Expected<IRCompileLayer::CompileFunction> {
264fe3594f7SNicolas Vasilache     auto TM = JTMB.createTargetMachine();
265fe3594f7SNicolas Vasilache     if (!TM)
266fe3594f7SNicolas Vasilache       return TM.takeError();
267fe3594f7SNicolas Vasilache     return IRCompileLayer::CompileFunction(
268fe3594f7SNicolas Vasilache         TMOwningSimpleCompiler(std::move(*TM), engine->cache.get()));
269fe3594f7SNicolas Vasilache   };
270fe3594f7SNicolas Vasilache 
271fe3594f7SNicolas Vasilache   // Create the LLJIT by calling the LLJITBuilder with 2 callbacks.
272fe3594f7SNicolas Vasilache   auto jit =
273fe3594f7SNicolas Vasilache       cantFail(llvm::orc::LLJITBuilder()
274fe3594f7SNicolas Vasilache                    .setCompileFunctionCreator(compileFunctionCreator)
275fe3594f7SNicolas Vasilache                    .setObjectLinkingLayerCreator(objectLinkingLayerCreator)
276fe3594f7SNicolas Vasilache                    .create());
277fe3594f7SNicolas Vasilache 
278fe3594f7SNicolas Vasilache   // Add a ThreadSafemodule to the engine and return.
279fe3594f7SNicolas Vasilache   ThreadSafeModule tsm(std::move(deserModule), std::move(ctx));
280*cf26e5faSNicolas Vasilache   if (transformer)
281*cf26e5faSNicolas Vasilache     cantFail(tsm.withModuleDo(
282*cf26e5faSNicolas Vasilache         [&](llvm::Module &module) { return transformer(&module); }));
283fe3594f7SNicolas Vasilache   cantFail(jit->addIRModule(std::move(tsm)));
284fe3594f7SNicolas Vasilache   engine->jit = std::move(jit);
2855a440378SAlex Zinenko 
286e7111fd6SJacques Pienaar   return std::move(engine);
2875a440378SAlex Zinenko }
2885a440378SAlex Zinenko 
2895a440378SAlex Zinenko Expected<void (*)(void **)> ExecutionEngine::lookup(StringRef name) const {
2905a440378SAlex Zinenko   auto expectedSymbol = jit->lookup(makePackedFunctionName(name));
2915a440378SAlex Zinenko   if (!expectedSymbol)
2925a440378SAlex Zinenko     return expectedSymbol.takeError();
2935a440378SAlex Zinenko   auto rawFPtr = expectedSymbol->getAddress();
2945a440378SAlex Zinenko   auto fptr = reinterpret_cast<void (*)(void **)>(rawFPtr);
2955a440378SAlex Zinenko   if (!fptr)
2965a440378SAlex Zinenko     return make_string_error("looked up function is null");
2975a440378SAlex Zinenko   return fptr;
2985a440378SAlex Zinenko }
299629f5b7fSNicolas Vasilache 
300fe3594f7SNicolas Vasilache Error ExecutionEngine::invoke(StringRef name, MutableArrayRef<void *> args) {
301629f5b7fSNicolas Vasilache   auto expectedFPtr = lookup(name);
302629f5b7fSNicolas Vasilache   if (!expectedFPtr)
303629f5b7fSNicolas Vasilache     return expectedFPtr.takeError();
304629f5b7fSNicolas Vasilache   auto fptr = *expectedFPtr;
305629f5b7fSNicolas Vasilache 
306629f5b7fSNicolas Vasilache   (*fptr)(args.data());
307629f5b7fSNicolas Vasilache 
308fe3594f7SNicolas Vasilache   return Error::success();
309629f5b7fSNicolas Vasilache }
310fe3594f7SNicolas Vasilache } // end namespace mlir
311