1 //===- ExecutionEngine.cpp - C API for MLIR JIT ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir-c/ExecutionEngine.h"
10 #include "mlir/CAPI/ExecutionEngine.h"
11 #include "mlir/CAPI/IR.h"
12 #include "mlir/CAPI/Support.h"
13 #include "mlir/ExecutionEngine/OptUtils.h"
14 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
15 #include "llvm/ExecutionEngine/Orc/Mangling.h"
16 #include "llvm/Support/TargetSelect.h"
17
18 using namespace mlir;
19
20 extern "C" MlirExecutionEngine
mlirExecutionEngineCreate(MlirModule op,int optLevel,int numPaths,const MlirStringRef * sharedLibPaths)21 mlirExecutionEngineCreate(MlirModule op, int optLevel, int numPaths,
22 const MlirStringRef *sharedLibPaths) {
23 static bool initOnce = [] {
24 llvm::InitializeNativeTarget();
25 llvm::InitializeNativeTargetAsmParser(); // needed for inline_asm
26 llvm::InitializeNativeTargetAsmPrinter();
27 return true;
28 }();
29 (void)initOnce;
30
31 mlir::registerLLVMDialectTranslation(*unwrap(op)->getContext());
32
33 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
34 if (!tmBuilderOrError) {
35 llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
36 return MlirExecutionEngine{nullptr};
37 }
38 auto tmOrError = tmBuilderOrError->createTargetMachine();
39 if (!tmOrError) {
40 llvm::errs() << "Failed to create a TargetMachine for the host\n";
41 return MlirExecutionEngine{nullptr};
42 }
43
44 SmallVector<StringRef> libPaths;
45 for (unsigned i = 0; i < static_cast<unsigned>(numPaths); ++i)
46 libPaths.push_back(sharedLibPaths[i].data);
47
48 // Create a transformer to run all LLVM optimization passes at the
49 // specified optimization level.
50 auto llvmOptLevel = static_cast<llvm::CodeGenOpt::Level>(optLevel);
51 auto transformer = mlir::makeOptimizingTransformer(
52 llvmOptLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get());
53 ExecutionEngineOptions jitOptions;
54 jitOptions.transformer = transformer;
55 jitOptions.jitCodeGenOptLevel = llvmOptLevel;
56 jitOptions.sharedLibPaths = libPaths;
57 auto jitOrError = ExecutionEngine::create(unwrap(op), jitOptions);
58 if (!jitOrError) {
59 consumeError(jitOrError.takeError());
60 return MlirExecutionEngine{nullptr};
61 }
62 return wrap(jitOrError->release());
63 }
64
mlirExecutionEngineDestroy(MlirExecutionEngine jit)65 extern "C" void mlirExecutionEngineDestroy(MlirExecutionEngine jit) {
66 delete (unwrap(jit));
67 }
68
69 extern "C" MlirLogicalResult
mlirExecutionEngineInvokePacked(MlirExecutionEngine jit,MlirStringRef name,void ** arguments)70 mlirExecutionEngineInvokePacked(MlirExecutionEngine jit, MlirStringRef name,
71 void **arguments) {
72 const std::string ifaceName = ("_mlir_ciface_" + unwrap(name)).str();
73 llvm::Error error = unwrap(jit)->invokePacked(
74 ifaceName, MutableArrayRef<void *>{arguments, (size_t)0});
75 if (error)
76 return wrap(failure());
77 return wrap(success());
78 }
79
mlirExecutionEngineLookupPacked(MlirExecutionEngine jit,MlirStringRef name)80 extern "C" void *mlirExecutionEngineLookupPacked(MlirExecutionEngine jit,
81 MlirStringRef name) {
82 auto expectedFPtr = unwrap(jit)->lookupPacked(unwrap(name));
83 if (!expectedFPtr)
84 return nullptr;
85 return reinterpret_cast<void *>(*expectedFPtr);
86 }
87
mlirExecutionEngineLookup(MlirExecutionEngine jit,MlirStringRef name)88 extern "C" void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
89 MlirStringRef name) {
90 auto expectedFPtr = unwrap(jit)->lookup(unwrap(name));
91 if (!expectedFPtr)
92 return nullptr;
93 return reinterpret_cast<void *>(*expectedFPtr);
94 }
95
mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,MlirStringRef name,void * sym)96 extern "C" void mlirExecutionEngineRegisterSymbol(MlirExecutionEngine jit,
97 MlirStringRef name,
98 void *sym) {
99 unwrap(jit)->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
100 llvm::orc::SymbolMap symbolMap;
101 symbolMap[interner(unwrap(name))] =
102 llvm::JITEvaluatedSymbol::fromPointer(sym);
103 return symbolMap;
104 });
105 }
106
mlirExecutionEngineDumpToObjectFile(MlirExecutionEngine jit,MlirStringRef name)107 extern "C" void mlirExecutionEngineDumpToObjectFile(MlirExecutionEngine jit,
108 MlirStringRef name) {
109 unwrap(jit)->dumpToObjectFile(unwrap(name));
110 }
111