1 //===- ExecutionEngine.cpp - C API for MLIR JIT ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir-c/ExecutionEngine.h"
10 #include "mlir/CAPI/ExecutionEngine.h"
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Support.h"
13 #include "mlir/ExecutionEngine/OptUtils.h"
14 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
15 #include "llvm/ExecutionEngine/Orc/Mangling.h"
16 #include "llvm/Support/TargetSelect.h"
17 
18 using namespace mlir;
19 
20 extern "C" MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op,
21                                                          int optLevel) {
22   static bool initOnce = [] {
23     llvm::InitializeNativeTarget();
24     llvm::InitializeNativeTargetAsmPrinter();
25     return true;
26   }();
27   (void)initOnce;
28 
29   mlir::registerLLVMDialectTranslation(*unwrap(op)->getContext());
30 
31   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
32   if (!tmBuilderOrError) {
33     llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
34     return MlirExecutionEngine{nullptr};
35   }
36   auto tmOrError = tmBuilderOrError->createTargetMachine();
37   if (!tmOrError) {
38     llvm::errs() << "Failed to create a TargetMachine for the host\n";
39     return MlirExecutionEngine{nullptr};
40   }
41 
42   // Create a transformer to run all LLVM optimization passes at the
43   // specified optimization level.
44   auto llvmOptLevel = static_cast<llvm::CodeGenOpt::Level>(optLevel);
45   auto transformer = mlir::makeLLVMPassesTransformer(
46       /*passes=*/{}, llvmOptLevel, /*targetMachine=*/tmOrError->get());
47   auto jitOrError = ExecutionEngine::create(
48       unwrap(op), /*llvmModuleBuilder=*/{}, transformer, llvmOptLevel);
49   if (!jitOrError) {
50     consumeError(jitOrError.takeError());
51     return MlirExecutionEngine{nullptr};
52   }
53   return wrap(jitOrError->release());
54 }
55 
56 extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) {
57   delete (unwrap(jit));
58 }
59 
60 extern "C" MlirLogicalResult
61 mlirExecutionEngineInvokePacked(MlirExecutionEngine jit, MlirStringRef name,
62                                 void **arguments) {
63   const std::string ifaceName = ("_mlir_ciface_" + unwrap(name)).str();
64   llvm::Error error = unwrap(jit)->invokePacked(
65       ifaceName, MutableArrayRef<void *>{arguments, (size_t)0});
66   if (error)
67     return wrap(failure());
68   return wrap(success());
69 }
70 
71 extern "C" void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
72                                            MlirStringRef name) {
73   auto expectedFPtr = unwrap(jit)->lookup(unwrap(name));
74   if (!expectedFPtr)
75     return nullptr;
76   return reinterpret_cast<void *>(*expectedFPtr);
77 }
78 
79 extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
80                                                   MlirStringRef name,
81                                                   void *sym) {
82   unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
83     llvm::orc::SymbolMap symbolMap;
84     symbolMap[interner(unwrap(name))] =
85         llvm::JITEvaluatedSymbol::fromPointer(sym);
86     return symbolMap;
87   });
88 }
89 
90 extern "C" void mlirExecutionEngineDumpToObjectFile(MlirExecutionEngine jit,
91                                                     MlirStringRef name) {
92   unwrap(jit)->dumpToObjectFile(unwrap(name));
93 }
94