1 //===--- OrcCAPITest.cpp - Unit tests for the OrcJIT v2 C API ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm-c/Core.h"
10 #include "llvm-c/Error.h"
11 #include "llvm-c/LLJIT.h"
12 #include "llvm-c/Orc.h"
13 #include "gtest/gtest.h"
14 
15 #include "llvm/ADT/Triple.h"
16 #include "llvm/ExecutionEngine/Orc/CompileUtils.h"
17 #include "llvm/IR/LLVMContext.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/IRReader/IRReader.h"
20 #include "llvm/Support/Error.h"
21 #include "llvm/Support/FormatVariadic.h"
22 #include "llvm/Support/SourceMgr.h"
23 #include "llvm/Testing/Support/Error.h"
24 #include <string>
25 
26 using namespace llvm;
27 using namespace llvm::orc;
28 
// Define the wrap()/unwrap() helpers that convert between the C++
// ThreadSafeModule class and the C API's opaque LLVMOrcThreadSafeModuleRef
// handle (wrap() is used by the fixture's createTestModule helper).
DEFINE_SIMPLE_CONVERSION_FUNCTIONS(ThreadSafeModule, LLVMOrcThreadSafeModuleRef)
30 
31 // OrcCAPITestBase contains several helper methods and pointers for unit tests
32 // written for the LLVM-C API. It provides the following helpers:
33 //
34 // 1. Jit: an LLVMOrcLLJIT instance which is freed upon test exit
35 // 2. ExecutionSession: the LLVMOrcExecutionSession for the JIT
36 // 3. MainDylib: the main JITDylib for the LLJIT instance
37 // 4. materializationUnitFn: function pointer to an empty function, used for
38 //                           materialization unit testing
39 // 5. definitionGeneratorFn: function pointer for a basic
40 //                           LLVMOrcCAPIDefinitionGeneratorTryToGenerateFunction
41 // 6. createTestModule: helper method for creating a basic thread-safe-module
42 class OrcCAPITestBase : public testing::Test {
43 protected:
44   LLVMOrcLLJITRef Jit = nullptr;
45   LLVMOrcExecutionSessionRef ExecutionSession = nullptr;
46   LLVMOrcJITDylibRef MainDylib = nullptr;
47 
48 public:
SetUpTestCase()49   static void SetUpTestCase() {
50     LLVMInitializeNativeTarget();
51     LLVMInitializeNativeAsmParser();
52     LLVMInitializeNativeAsmPrinter();
53 
54     // Attempt to set up a JIT instance once to verify that we can.
55     LLVMOrcJITTargetMachineBuilderRef JTMB = nullptr;
56     if (LLVMErrorRef E = LLVMOrcJITTargetMachineBuilderDetectHost(&JTMB)) {
57       // If setup fails then disable these tests.
58       LLVMConsumeError(E);
59       TargetSupported = false;
60       return;
61     }
62 
63     // Capture the target triple. We'll use it for both verification that
64     // this target is *supposed* to be supported, and error messages in
65     // the case that it fails anyway.
66     char *TT = LLVMOrcJITTargetMachineBuilderGetTargetTriple(JTMB);
67     TargetTriple = TT;
68     LLVMDisposeMessage(TT);
69 
70     if (!isSupported(TargetTriple)) {
71       // If this triple isn't supported then bail out.
72       TargetSupported = false;
73       LLVMOrcDisposeJITTargetMachineBuilder(JTMB);
74       return;
75     }
76 
77     LLVMOrcLLJITBuilderRef Builder = LLVMOrcCreateLLJITBuilder();
78     LLVMOrcLLJITBuilderSetJITTargetMachineBuilder(Builder, JTMB);
79     LLVMOrcLLJITRef J;
80     if (LLVMErrorRef E = LLVMOrcCreateLLJIT(&J, Builder)) {
81       // If setup fails then disable these tests.
82       TargetSupported = false;
83       LLVMConsumeError(E);
84       return;
85     }
86 
87     LLVMOrcDisposeLLJIT(J);
88     TargetSupported = true;
89   }
90 
SetUp()91   void SetUp() override {
92     if (!TargetSupported)
93       GTEST_SKIP();
94 
95     LLVMOrcJITTargetMachineBuilderRef JTMB = nullptr;
96     LLVMErrorRef E1 = LLVMOrcJITTargetMachineBuilderDetectHost(&JTMB);
97     assert(E1 == LLVMErrorSuccess && "Expected call to detect host to succeed");
98     (void)E1;
99 
100     LLVMOrcLLJITBuilderRef Builder = LLVMOrcCreateLLJITBuilder();
101     LLVMOrcLLJITBuilderSetJITTargetMachineBuilder(Builder, JTMB);
102     LLVMErrorRef E2 = LLVMOrcCreateLLJIT(&Jit, Builder);
103     assert(E2 == LLVMErrorSuccess &&
104            "Expected call to create LLJIT to succeed");
105     (void)E2;
106     ExecutionSession = LLVMOrcLLJITGetExecutionSession(Jit);
107     MainDylib = LLVMOrcLLJITGetMainJITDylib(Jit);
108   }
TearDown()109   void TearDown() override {
110     // Check whether Jit has already been torn down -- we allow clients to do
111     // this manually to check teardown behavior.
112     if (Jit) {
113       LLVMOrcDisposeLLJIT(Jit);
114       Jit = nullptr;
115     }
116   }
117 
118 protected:
isSupported(StringRef Triple)119   static bool isSupported(StringRef Triple) {
120     // TODO: Print error messages in failure logs, use them to audit this list.
121     // Some architectures may be unsupportable or missing key components, but
122     // some may just be failing due to bugs in this testcase.
123     if (Triple.startswith("armv7") || Triple.startswith("armv8l"))
124       return false;
125     llvm::Triple T(Triple);
126     if (T.isOSAIX() && T.isPPC64())
127       return false;
128     return true;
129   }
130 
materializationUnitFn()131   static void materializationUnitFn() {}
132 
133   // Stub definition generator, where all Names are materialized from the
134   // materializationUnitFn() test function and defined into the JIT Dylib
135   static LLVMErrorRef
definitionGeneratorFn(LLVMOrcDefinitionGeneratorRef G,void * Ctx,LLVMOrcLookupStateRef * LS,LLVMOrcLookupKind K,LLVMOrcJITDylibRef JD,LLVMOrcJITDylibLookupFlags F,LLVMOrcCLookupSet Names,size_t NamesCount)136   definitionGeneratorFn(LLVMOrcDefinitionGeneratorRef G, void *Ctx,
137                         LLVMOrcLookupStateRef *LS, LLVMOrcLookupKind K,
138                         LLVMOrcJITDylibRef JD, LLVMOrcJITDylibLookupFlags F,
139                         LLVMOrcCLookupSet Names, size_t NamesCount) {
140     for (size_t I = 0; I < NamesCount; I++) {
141       LLVMOrcCLookupSetElement Element = Names[I];
142       LLVMOrcJITTargetAddress Addr =
143           (LLVMOrcJITTargetAddress)(&materializationUnitFn);
144       LLVMJITSymbolFlags Flags = {LLVMJITSymbolGenericFlagsWeak, 0};
145       LLVMJITEvaluatedSymbol Sym = {Addr, Flags};
146       LLVMOrcRetainSymbolStringPoolEntry(Element.Name);
147       LLVMOrcCSymbolMapPair Pair = {Element.Name, Sym};
148       LLVMOrcCSymbolMapPair Pairs[] = {Pair};
149       LLVMOrcMaterializationUnitRef MU = LLVMOrcAbsoluteSymbols(Pairs, 1);
150       LLVMErrorRef Err = LLVMOrcJITDylibDefine(JD, MU);
151       if (Err)
152         return Err;
153     }
154     return LLVMErrorSuccess;
155   }
156 
createSMDiagnosticError(llvm::SMDiagnostic & Diag)157   static Error createSMDiagnosticError(llvm::SMDiagnostic &Diag) {
158     std::string Msg;
159     {
160       raw_string_ostream OS(Msg);
161       Diag.print("", OS);
162     }
163     return make_error<StringError>(std::move(Msg), inconvertibleErrorCode());
164   }
165 
166   // Create an LLVM IR module from the given StringRef.
167   static Expected<std::unique_ptr<Module>>
parseTestModule(LLVMContext & Ctx,StringRef Source,StringRef Name)168   parseTestModule(LLVMContext &Ctx, StringRef Source, StringRef Name) {
169     assert(TargetSupported &&
170            "Attempted to create module for unsupported target");
171     SMDiagnostic Err;
172     if (auto M = parseIR(MemoryBufferRef(Source, Name), Err, Ctx))
173       return std::move(M);
174     return createSMDiagnosticError(Err);
175   }
176 
177   // returns the sum of its two parameters
createTestModule(StringRef Source,StringRef Name)178   static LLVMOrcThreadSafeModuleRef createTestModule(StringRef Source,
179                                                      StringRef Name) {
180     auto Ctx = std::make_unique<LLVMContext>();
181     auto M = cantFail(parseTestModule(*Ctx, Source, Name));
182     return wrap(new ThreadSafeModule(std::move(M), std::move(Ctx)));
183   }
184 
createTestObject(StringRef Source,StringRef Name)185   static LLVMMemoryBufferRef createTestObject(StringRef Source,
186                                               StringRef Name) {
187     auto Ctx = std::make_unique<LLVMContext>();
188     auto M = cantFail(parseTestModule(*Ctx, Source, Name));
189 
190     auto JTMB = cantFail(JITTargetMachineBuilder::detectHost());
191     M->setDataLayout(cantFail(JTMB.getDefaultDataLayoutForTarget()));
192     auto TM = cantFail(JTMB.createTargetMachine());
193 
194     SimpleCompiler SC(*TM);
195     auto ObjBuffer = cantFail(SC(*M));
196     return wrap(ObjBuffer.release());
197   }
198 
199   static std::string TargetTriple;
200   static bool TargetSupported;
201 };
202 
203 std::string OrcCAPITestBase::TargetTriple;
204 bool OrcCAPITestBase::TargetSupported = false;
205 
206 namespace {
207 
208 constexpr StringRef SumExample =
209     R"(
210     define i32 @sum(i32 %x, i32 %y) {
211     entry:
212       %r = add nsw i32 %x, %y
213       ret i32 %r
214     }
215   )";
216 
217 } // end anonymous namespace.
218 
219 // Consumes the given error ref and returns the string error message.
toString(LLVMErrorRef E)220 static std::string toString(LLVMErrorRef E) {
221   char *ErrMsg = LLVMGetErrorMessage(E);
222   std::string Result(ErrMsg);
223   LLVMDisposeErrorMessage(ErrMsg);
224   return Result;
225 }
226 
TEST_F(OrcCAPITestBase, SymbolStringPoolUniquing) {
  // Interning the same string twice must yield the same pool entry, while a
  // different string must yield a distinct entry.
  LLVMOrcSymbolStringPoolEntryRef E1 =
      LLVMOrcExecutionSessionIntern(ExecutionSession, "aaa");
  LLVMOrcSymbolStringPoolEntryRef E2 =
      LLVMOrcExecutionSessionIntern(ExecutionSession, "aaa");
  LLVMOrcSymbolStringPoolEntryRef E3 =
      LLVMOrcExecutionSessionIntern(ExecutionSession, "bbb");
  const char *SymbolName = LLVMOrcSymbolStringPoolEntryStr(E1);
  // Non-fatal checks: the entries below must be released even when a check
  // fails, otherwise a failure would also leave pool references outstanding.
  EXPECT_EQ(E1, E2) << "String pool entries are not unique";
  EXPECT_NE(E1, E3) << "Unique symbol pool entries are equal";
  EXPECT_STREQ("aaa", SymbolName) << "String value of symbol is not equal";
  LLVMOrcReleaseSymbolStringPoolEntry(E1);
  LLVMOrcReleaseSymbolStringPoolEntry(E2);
  LLVMOrcReleaseSymbolStringPoolEntry(E3);
}
242 
TEST_F(OrcCAPITestBase, JITDylibLookup) {
  // Looking a JITDylib up by name must find nothing before it is created, and
  // the very same JITDylib once it has been created.
  const char *DylibName = "test";
  ASSERT_EQ(LLVMOrcExecutionSessionGetJITDylibByName(ExecutionSession,
                                                     DylibName),
            nullptr);
  LLVMOrcJITDylibRef Created =
      LLVMOrcExecutionSessionCreateBareJITDylib(ExecutionSession, DylibName);
  LLVMOrcJITDylibRef Found =
      LLVMOrcExecutionSessionGetJITDylibByName(ExecutionSession, DylibName);
  ASSERT_EQ(Created, Found) << "Located JIT Dylib is not equal to original";
}
253 
TEST_F(OrcCAPITestBase, MaterializationUnitCreation) {
  // Define "test" as a weak absolute symbol pointing at materializationUnitFn,
  // then check that an LLJIT lookup of "test" yields exactly that address.
  LLVMOrcSymbolStringPoolEntryRef Name =
      LLVMOrcLLJITMangleAndIntern(Jit, "test");
  LLVMJITSymbolFlags Flags = {LLVMJITSymbolGenericFlagsWeak, 0};
  LLVMOrcJITTargetAddress Addr =
      (LLVMOrcJITTargetAddress)(&materializationUnitFn);
  LLVMJITEvaluatedSymbol Sym = {Addr, Flags};
  // The materialization unit takes over our reference to Name.
  LLVMOrcCSymbolMapPair Pair = {Name, Sym};
  LLVMOrcCSymbolMapPair Pairs[] = {Pair};
  LLVMOrcMaterializationUnitRef MU = LLVMOrcAbsoluteSymbols(Pairs, 1);
  if (LLVMErrorRef E = LLVMOrcJITDylibDefine(MainDylib, MU)) {
    // On failure MainDylib does not take ownership of MU.
    LLVMOrcDisposeMaterializationUnit(MU);
    FAIL() << "Unexpected error while adding \"test\" symbol (triple = "
           << TargetTriple << "): " << toString(E);
  }
  LLVMOrcJITTargetAddress OutAddr = 0;
  if (LLVMErrorRef E = LLVMOrcLLJITLookup(Jit, &OutAddr, "test"))
    FAIL() << "Failed to look up \"test\" symbol (triple = " << TargetTriple
           << "): " << toString(E);
  ASSERT_EQ(Addr, OutAddr);
}
273 
274 struct ExecutionSessionLookupHelper {
275   bool ExpectSuccess = true;
276   bool CallbackReceived = false;
277   size_t NumExpectedPairs;
278   LLVMOrcCSymbolMapPair *ExpectedMapping;
279 };
280 
executionSessionLookupHandlerCallback(LLVMErrorRef Err,LLVMOrcCSymbolMapPairs Result,size_t NumPairs,void * RawCtx)281 static void executionSessionLookupHandlerCallback(LLVMErrorRef Err,
282                                                   LLVMOrcCSymbolMapPairs Result,
283                                                   size_t NumPairs,
284                                                   void *RawCtx) {
285   auto *Ctx = static_cast<ExecutionSessionLookupHelper *>(RawCtx);
286   Ctx->CallbackReceived = true;
287   if (Ctx->ExpectSuccess) {
288     EXPECT_THAT_ERROR(unwrap(Err), Succeeded());
289     EXPECT_EQ(NumPairs, Ctx->NumExpectedPairs)
290         << "Expected " << Ctx->NumExpectedPairs << " entries in result, got "
291         << NumPairs;
292     auto ExpectedMappingEnd = Ctx->ExpectedMapping + Ctx->NumExpectedPairs;
293     for (unsigned I = 0; I != NumPairs; ++I) {
294       auto J =
295           std::find_if(Ctx->ExpectedMapping, ExpectedMappingEnd,
296                        [N = Result[I].Name](const LLVMOrcCSymbolMapPair &Val) {
297                          return Val.Name == N;
298                        });
299       EXPECT_NE(J, ExpectedMappingEnd)
300           << "Missing symbol \""
301           << LLVMOrcSymbolStringPoolEntryStr(Result[I].Name) << "\"";
302       if (J != ExpectedMappingEnd) {
303         EXPECT_EQ(Result[I].Sym.Address, J->Sym.Address)
304             << "Result map for \"" << Result[I].Name
305             << "\" differs from expected value: "
306             << formatv("{0:x} vs {1:x}", Result[I].Sym.Address, J->Sym.Address);
307       }
308     }
309   } else
310     EXPECT_THAT_ERROR(unwrap(Err), Failed());
311 }
312 
TEST_F(OrcCAPITestBase,ExecutionSessionLookup_Success)313 TEST_F(OrcCAPITestBase, ExecutionSessionLookup_Success) {
314   // Test a successful generic lookup. We will look up three symbols over two
315   // JITDylibs: { "Foo" (Required), "Bar" (Weakly-ref), "Baz" (Required) } over
316   // { MainJITDylib (Exported-only), ExtraJD (All symbols) }.
317   //
318   // Foo will be defined as exported in MainJD.
319   // Bar will be defined as non-exported in MainJD.
320   // Baz will be defined as non-exported in ExtraJD.
321   //
322   // This will require (1) that we find the regular exported symbol Foo in
323   // MainJD, (2) that we *don't* find the non-exported symbol Bar in MainJD
324   // but also don't error (since it's weakly referenced), and (3) that we
325   // find the non-exported symbol Baz in ExtraJD (since we're searching all
326   // symbols in ExtraJD).
327 
328   ExecutionSessionLookupHelper H;
329   LLVMOrcSymbolStringPoolEntryRef Foo = LLVMOrcLLJITMangleAndIntern(Jit, "Foo");
330   LLVMOrcSymbolStringPoolEntryRef Bar = LLVMOrcLLJITMangleAndIntern(Jit, "Bar");
331   LLVMOrcSymbolStringPoolEntryRef Baz = LLVMOrcLLJITMangleAndIntern(Jit, "Baz");
332 
333   // Create ExtraJD.
334   LLVMOrcJITDylibRef ExtraJD = nullptr;
335   if (auto E = LLVMOrcExecutionSessionCreateJITDylib(ExecutionSession, &ExtraJD,
336                                                      "ExtraJD")) {
337     FAIL() << "Unexpected error while creating JITDylib \"ExtraJD\" (triple = "
338            << TargetTriple << "): " << toString(E);
339     return;
340   }
341 
342   // Add exported symbols "Foo" and "Bar" to Main JITDylib.
343   LLVMOrcRetainSymbolStringPoolEntry(Foo);
344   LLVMOrcRetainSymbolStringPoolEntry(Bar);
345   LLVMOrcCSymbolMapPair MainJDPairs[] = {
346       {Foo, {0x1, {LLVMJITSymbolGenericFlagsExported, 0}}},
347       {Bar, {0x2, {LLVMJITSymbolGenericFlagsNone, 0}}}};
348   LLVMOrcMaterializationUnitRef MainJDMU =
349       LLVMOrcAbsoluteSymbols(MainJDPairs, 2);
350   if (LLVMErrorRef E = LLVMOrcJITDylibDefine(MainDylib, MainJDMU))
351     FAIL() << "Unexpected error while adding MainDylib symbols (triple = "
352            << TargetTriple << "): " << toString(E);
353 
354   // Add non-exported symbol "Baz" to ExtraJD.
355   LLVMOrcRetainSymbolStringPoolEntry(Baz);
356   LLVMOrcCSymbolMapPair ExtraJDPairs[] = {
357       {Baz, {0x3, {LLVMJITSymbolGenericFlagsNone, 0}}}};
358   LLVMOrcMaterializationUnitRef ExtraJDMU =
359       LLVMOrcAbsoluteSymbols(ExtraJDPairs, 1);
360   if (LLVMErrorRef E = LLVMOrcJITDylibDefine(ExtraJD, ExtraJDMU))
361     FAIL() << "Unexpected error while adding ExtraJD symbols (triple = "
362            << TargetTriple << "): " << toString(E);
363 
364   // Create expected mapping for result:
365   LLVMOrcCSymbolMapPair ExpectedMapping[] = {
366       {Foo, {0x1, {LLVMJITSymbolGenericFlagsExported, 0}}},
367       {Baz, {0x3, {LLVMJITSymbolGenericFlagsNone, 0}}}};
368   H.ExpectedMapping = ExpectedMapping;
369   H.NumExpectedPairs = 2;
370 
371   // Issue the lookup. We're using the default same-thread dispatch, so the
372   // handler should have run by the time we return from this call.
373   LLVMOrcCJITDylibSearchOrderElement SO[] = {
374       {MainDylib, LLVMOrcJITDylibLookupFlagsMatchExportedSymbolsOnly},
375       {ExtraJD, LLVMOrcJITDylibLookupFlagsMatchAllSymbols}};
376 
377   LLVMOrcRetainSymbolStringPoolEntry(Foo);
378   LLVMOrcRetainSymbolStringPoolEntry(Bar);
379   LLVMOrcRetainSymbolStringPoolEntry(Baz);
380   LLVMOrcCLookupSetElement LS[] = {
381       {Foo, LLVMOrcSymbolLookupFlagsRequiredSymbol},
382       {Bar, LLVMOrcSymbolLookupFlagsWeaklyReferencedSymbol},
383       {Baz, LLVMOrcSymbolLookupFlagsRequiredSymbol}};
384   LLVMOrcExecutionSessionLookup(ExecutionSession, LLVMOrcLookupKindStatic, SO,
385                                 2, LS, 3, executionSessionLookupHandlerCallback,
386                                 &H);
387 
388   EXPECT_TRUE(H.CallbackReceived) << "Lookup callback never received";
389 
390   // Release our local string ptrs.
391   LLVMOrcReleaseSymbolStringPoolEntry(Baz);
392   LLVMOrcReleaseSymbolStringPoolEntry(Bar);
393   LLVMOrcReleaseSymbolStringPoolEntry(Foo);
394 }
395 
TEST_F(OrcCAPITestBase, ExecutionSessionLookup_Failure) {
  // Look up "Foo" in MainDylib without ever defining it: the lookup is
  // expected to complete with a symbol-not-found error.
  ExecutionSessionLookupHelper H;
  H.ExpectSuccess = false;

  LLVMOrcSymbolStringPoolEntryRef Foo = LLVMOrcLLJITMangleAndIntern(Jit, "Foo");
  LLVMOrcCLookupSetElement LS[] = {
      {Foo, LLVMOrcSymbolLookupFlagsRequiredSymbol}};
  LLVMOrcCJITDylibSearchOrderElement SO[] = {
      {MainDylib, LLVMOrcJITDylibLookupFlagsMatchExportedSymbolsOnly}};

  const size_t NumSearchOrderEntries = 1;
  const size_t NumLookupEntries = 1;
  LLVMOrcExecutionSessionLookup(
      ExecutionSession, LLVMOrcLookupKindStatic, SO, NumSearchOrderEntries, LS,
      NumLookupEntries, executionSessionLookupHandlerCallback, &H);

  EXPECT_TRUE(H.CallbackReceived) << "Lookup callback never received";
}
413 
TEST_F(OrcCAPITestBase, DefinitionGenerators) {
  // Attach definitionGeneratorFn to MainDylib; a lookup of the otherwise
  // undefined symbol "test" should then resolve to materializationUnitFn.
  LLVMOrcJITDylibAddGenerator(
      MainDylib, LLVMOrcCreateCustomCAPIDefinitionGenerator(
                     &definitionGeneratorFn, nullptr, nullptr));

  LLVMOrcJITTargetAddress OutAddr;
  LLVMErrorRef LookupErr = LLVMOrcLLJITLookup(Jit, &OutAddr, "test");
  if (LookupErr)
    FAIL() << "The DefinitionGenerator did not create symbol \"test\" "
           << "(triple = " << TargetTriple << "): " << toString(LookupErr);

  const auto ExpectedAddr = (LLVMOrcJITTargetAddress)(&materializationUnitFn);
  ASSERT_EQ(ExpectedAddr, OutAddr);
}
427 
TEST_F(OrcCAPITestBase, ResourceTrackerDefinitionLifetime) {
  // This test case ensures that all symbols loaded into a JITDylib with a
  // ResourceTracker attached are cleared from the JITDylib once the RT is
  // removed.
  LLVMOrcResourceTrackerRef RT =
      LLVMOrcJITDylibCreateResourceTracker(MainDylib);
  LLVMOrcThreadSafeModuleRef TSM = createTestModule(SumExample, "sum.ll");
  if (LLVMErrorRef E = LLVMOrcLLJITAddLLVMIRModuleWithRT(Jit, RT, TSM)) {
    LLVMOrcReleaseResourceTracker(RT);
    FAIL() << "Failed to add LLVM IR module to LLJIT (triple = " << TargetTriple
           << "): " << toString(E);
  }
  LLVMOrcJITTargetAddress TestFnAddr = 0;
  if (LLVMErrorRef E = LLVMOrcLLJITLookup(Jit, &TestFnAddr, "sum")) {
    LLVMOrcReleaseResourceTracker(RT);
    FAIL() << "Symbol \"sum\" was not added into JIT (triple = " << TargetTriple
           << "): " << toString(E);
  }
  // Non-fatal so that RT is still removed and released below.
  EXPECT_TRUE(!!TestFnAddr);

  // Removal returns an error; check it instead of silently dropping it.
  LLVMErrorRef RemoveErr = LLVMOrcResourceTrackerRemove(RT);
  // Our handle is no longer needed once RT's resources have been removed.
  LLVMOrcReleaseResourceTracker(RT);
  if (RemoveErr)
    FAIL() << "Failed to remove resource tracker (triple = " << TargetTriple
           << "): " << toString(RemoveErr);

  // "sum" must no longer be resolvable.
  LLVMOrcJITTargetAddress OutAddr = 0;
  LLVMErrorRef Err = LLVMOrcLLJITLookup(Jit, &OutAddr, "sum");
  ASSERT_TRUE(Err);
  LLVMConsumeError(Err);

  ASSERT_FALSE(OutAddr);
}
452 
TEST_F(OrcCAPITestBase, ResourceTrackerTransfer) {
  // Add "sum" under MainDylib's default resource tracker, transfer its
  // resources to a second tracker RT2, and check that "sum" is still
  // resolvable after the transfer.
  LLVMOrcResourceTrackerRef DefaultRT =
      LLVMOrcJITDylibGetDefaultResourceTracker(MainDylib);
  LLVMOrcResourceTrackerRef RT2 =
      LLVMOrcJITDylibCreateResourceTracker(MainDylib);
  LLVMOrcThreadSafeModuleRef TSM = createTestModule(SumExample, "sum.ll");
  if (LLVMErrorRef E =
          LLVMOrcLLJITAddLLVMIRModuleWithRT(Jit, DefaultRT, TSM)) {
    LLVMOrcReleaseResourceTracker(RT2);
    FAIL() << "Failed to add LLVM IR module to LLJIT (triple = " << TargetTriple
           << "): " << toString(E);
  }
  LLVMOrcJITTargetAddress Addr = 0;
  if (LLVMErrorRef E = LLVMOrcLLJITLookup(Jit, &Addr, "sum")) {
    LLVMOrcReleaseResourceTracker(RT2);
    FAIL() << "Symbol \"sum\" was not added into JIT (triple = " << TargetTriple
           << "): " << toString(E);
  }
  LLVMOrcResourceTrackerTransferTo(DefaultRT, RT2);
  LLVMErrorRef Err = LLVMOrcLLJITLookup(Jit, &Addr, "sum");
  // Release RT2 only after the lookup, so the lookup runs while RT2 holds
  // the transferred resources.
  LLVMOrcReleaseResourceTracker(RT2);
  if (Err)
    FAIL() << "Symbol \"sum\" was not found after resource transfer (triple = "
           << TargetTriple << "): " << toString(Err);
}
471 
TEST_F(OrcCAPITestBase, AddObjectBuffer) {
  // Compile SumExample to an in-memory object file, hand it straight to the
  // LLJIT's object linking layer, and check that "sum" can then be found.
  LLVMMemoryBufferRef ObjBuffer = createTestObject(SumExample, "sum.ll");
  LLVMOrcObjectLayerRef ObjLinkingLayer = LLVMOrcLLJITGetObjLinkingLayer(Jit);

  LLVMErrorRef AddErr =
      LLVMOrcObjectLayerAddObjectFile(ObjLinkingLayer, MainDylib, ObjBuffer);
  if (AddErr)
    FAIL() << "Failed to add object file to ObjLinkingLayer (triple = "
           << TargetTriple << "): " << toString(AddErr);

  LLVMOrcJITTargetAddress SumAddr;
  LLVMErrorRef LookupErr = LLVMOrcLLJITLookup(Jit, &SumAddr, "sum");
  if (LookupErr)
    FAIL() << "Symbol \"sum\" was not added into JIT (triple = " << TargetTriple
           << "): " << toString(LookupErr);
  ASSERT_TRUE(!!SumAddr);
}
487 
TEST_F(OrcCAPITestBase, ExecutionTest) {
  using SumFunctionType = int32_t (*)(int32_t, int32_t);

  // This test performs OrcJIT compilation of a simple sum module and then
  // calls the JIT'd function. The native target and asm printer have already
  // been initialized by SetUpTestCase().
  LLVMOrcThreadSafeModuleRef TSM = createTestModule(SumExample, "sum.ll");
  if (LLVMErrorRef E = LLVMOrcLLJITAddLLVMIRModule(Jit, MainDylib, TSM))
    FAIL() << "Failed to add LLVM IR module to LLJIT (triple = " << TargetTriple
           << "): " << toString(E);
  LLVMOrcJITTargetAddress TestFnAddr = 0;
  if (LLVMErrorRef E = LLVMOrcLLJITLookup(Jit, &TestFnAddr, "sum"))
    FAIL() << "Symbol \"sum\" was not added into JIT (triple = " << TargetTriple
           << "): " << toString(E);
  // Narrow the 64-bit target address to a pointer-sized integer before
  // converting it to a function pointer.
  auto SumFn = reinterpret_cast<SumFunctionType>(
      static_cast<uintptr_t>(TestFnAddr));
  int32_t Result = SumFn(1, 1);
  ASSERT_EQ(2, Result);
}
505 
Destroy(void * Ctx)506 void Destroy(void *Ctx) {}
507 
TargetFn()508 void TargetFn() {}
509 
Materialize(void * Ctx,LLVMOrcMaterializationResponsibilityRef MR)510 void Materialize(void *Ctx, LLVMOrcMaterializationResponsibilityRef MR) {
511   LLVMOrcJITDylibRef JD =
512       LLVMOrcMaterializationResponsibilityGetTargetDylib(MR);
513   ASSERT_TRUE(!!JD);
514 
515   LLVMOrcExecutionSessionRef ES =
516       LLVMOrcMaterializationResponsibilityGetExecutionSession(MR);
517   ASSERT_TRUE(!!ES);
518 
519   LLVMOrcSymbolStringPoolEntryRef InitSym =
520       LLVMOrcMaterializationResponsibilityGetInitializerSymbol(MR);
521   ASSERT_TRUE(!InitSym);
522 
523   size_t NumSymbols;
524   LLVMOrcCSymbolFlagsMapPairs Symbols =
525       LLVMOrcMaterializationResponsibilityGetSymbols(MR, &NumSymbols);
526 
527   ASSERT_TRUE(!!Symbols);
528   ASSERT_EQ(NumSymbols, (size_t)1);
529 
530   LLVMOrcSymbolStringPoolEntryRef *RequestedSymbols =
531       LLVMOrcMaterializationResponsibilityGetRequestedSymbols(MR, &NumSymbols);
532 
533   ASSERT_TRUE(!!RequestedSymbols);
534   ASSERT_EQ(NumSymbols, (size_t)1);
535 
536   LLVMOrcCSymbolFlagsMapPair TargetSym = Symbols[0];
537 
538   ASSERT_EQ(RequestedSymbols[0], TargetSym.Name);
539   LLVMOrcRetainSymbolStringPoolEntry(TargetSym.Name);
540 
541   LLVMOrcDisposeCSymbolFlagsMap(Symbols);
542   LLVMOrcDisposeSymbols(RequestedSymbols);
543 
544   LLVMOrcJITTargetAddress Addr = (LLVMOrcJITTargetAddress)(&TargetFn);
545 
546   LLVMJITSymbolFlags Flags = {
547       LLVMJITSymbolGenericFlagsExported | LLVMJITSymbolGenericFlagsCallable, 0};
548   ASSERT_EQ(TargetSym.Flags.GenericFlags, Flags.GenericFlags);
549   ASSERT_EQ(TargetSym.Flags.TargetFlags, Flags.TargetFlags);
550 
551   LLVMJITEvaluatedSymbol Sym = {Addr, Flags};
552 
553   LLVMOrcLLJITRef J = (LLVMOrcLLJITRef)Ctx;
554 
555   LLVMOrcSymbolStringPoolEntryRef OtherSymbol =
556       LLVMOrcLLJITMangleAndIntern(J, "other");
557   LLVMOrcSymbolStringPoolEntryRef DependencySymbol =
558       LLVMOrcLLJITMangleAndIntern(J, "dependency");
559 
560   LLVMOrcRetainSymbolStringPoolEntry(OtherSymbol);
561   LLVMOrcRetainSymbolStringPoolEntry(DependencySymbol);
562   LLVMOrcCSymbolFlagsMapPair NewSymbols[] = {
563       {OtherSymbol, Flags},
564       {DependencySymbol, Flags},
565   };
566   LLVMOrcMaterializationResponsibilityDefineMaterializing(MR, NewSymbols, 2);
567 
568   LLVMOrcRetainSymbolStringPoolEntry(OtherSymbol);
569   LLVMOrcMaterializationResponsibilityRef OtherMR = NULL;
570   {
571     LLVMErrorRef Err = LLVMOrcMaterializationResponsibilityDelegate(
572         MR, &OtherSymbol, 1, &OtherMR);
573     if (Err) {
574       char *ErrMsg = LLVMGetErrorMessage(Err);
575       fprintf(stderr, "Error: %s\n", ErrMsg);
576       LLVMDisposeErrorMessage(ErrMsg);
577       LLVMOrcMaterializationResponsibilityFailMaterialization(MR);
578       LLVMOrcDisposeMaterializationResponsibility(MR);
579       return;
580     }
581   }
582   assert(OtherMR);
583 
584   LLVMOrcCSymbolMapPair OtherPair = {OtherSymbol, Sym};
585   LLVMOrcMaterializationUnitRef OtherMU = LLVMOrcAbsoluteSymbols(&OtherPair, 1);
586   // OtherSymbol is no longer owned by us
587   {
588     LLVMErrorRef Err =
589         LLVMOrcMaterializationResponsibilityReplace(OtherMR, OtherMU);
590     if (Err) {
591       char *ErrMsg = LLVMGetErrorMessage(Err);
592       fprintf(stderr, "Error: %s\n", ErrMsg);
593       LLVMDisposeErrorMessage(ErrMsg);
594 
595       LLVMOrcMaterializationResponsibilityFailMaterialization(OtherMR);
596       LLVMOrcMaterializationResponsibilityFailMaterialization(MR);
597 
598       LLVMOrcDisposeMaterializationResponsibility(OtherMR);
599       LLVMOrcDisposeMaterializationResponsibility(MR);
600       LLVMOrcDisposeMaterializationUnit(OtherMU);
601       return;
602     }
603   }
604   LLVMOrcDisposeMaterializationResponsibility(OtherMR);
605 
606   // FIXME: Implement async lookup
607   // A real test of the dependence tracking in the success case would require
608   // async lookups. You could:
609   // 1. Materialize foo, making foo depend on other.
610   // 2. In the caller, verify that the lookup callback for foo has not run (due
611   // to the dependence)
612   // 3. Materialize other by looking it up.
613   // 4. In the caller, verify that the lookup callback for foo has now run.
614 
615   LLVMOrcRetainSymbolStringPoolEntry(TargetSym.Name);
616   LLVMOrcRetainSymbolStringPoolEntry(DependencySymbol);
617   LLVMOrcCDependenceMapPair Dependency = {JD, {&DependencySymbol, 1}};
618   LLVMOrcMaterializationResponsibilityAddDependencies(MR, TargetSym.Name,
619                                                       &Dependency, 1);
620 
621   LLVMOrcRetainSymbolStringPoolEntry(DependencySymbol);
622   LLVMOrcMaterializationResponsibilityAddDependenciesForAll(MR, &Dependency, 1);
623 
624   // See FIXME above
625   LLVMOrcCSymbolMapPair Pair = {DependencySymbol, Sym};
626   LLVMOrcMaterializationResponsibilityNotifyResolved(MR, &Pair, 1);
627   // DependencySymbol no longer owned by us
628 
629   Pair = {TargetSym.Name, Sym};
630   LLVMOrcMaterializationResponsibilityNotifyResolved(MR, &Pair, 1);
631 
632   LLVMOrcMaterializationResponsibilityNotifyEmitted(MR);
633   LLVMOrcDisposeMaterializationResponsibility(MR);
634   return;
635 }
636 
TEST_F(OrcCAPITestBase, MaterializationResponsibility) {
  // Define "foo" through a custom materialization unit whose Materialize
  // callback exercises the MaterializationResponsibility API (also defining
  // "other" and "dependency"). All three symbols must resolve to TargetFn.
  LLVMJITSymbolFlags Flags = {
      LLVMJITSymbolGenericFlagsExported | LLVMJITSymbolGenericFlagsCallable, 0};
  LLVMOrcCSymbolFlagsMapPair Sym = {LLVMOrcLLJITMangleAndIntern(Jit, "foo"),
                                    Flags};

  LLVMOrcMaterializationUnitRef MU = LLVMOrcCreateCustomMaterializationUnit(
      "MU", (void *)Jit, &Sym, 1, nullptr, &Materialize, nullptr, &Destroy);
  // The definition can fail; check it rather than dropping the error. On
  // failure MainDylib does not take ownership of MU.
  if (LLVMErrorRef E = LLVMOrcJITDylibDefine(MainDylib, MU)) {
    LLVMOrcDisposeMaterializationUnit(MU);
    FAIL() << "Unexpected error while defining \"foo\" (triple = "
           << TargetTriple << "): " << toString(E);
  }

  LLVMOrcJITTargetAddress Addr = 0;
  if (LLVMErrorRef Err = LLVMOrcLLJITLookup(Jit, &Addr, "foo")) {
    FAIL() << "foo was not materialized " << toString(Err);
  }
  ASSERT_TRUE(!!Addr);
  ASSERT_EQ(Addr, (LLVMOrcJITTargetAddress)&TargetFn);

  if (LLVMErrorRef Err = LLVMOrcLLJITLookup(Jit, &Addr, "other")) {
    FAIL() << "other was not materialized " << toString(Err);
  }
  ASSERT_TRUE(!!Addr);
  ASSERT_EQ(Addr, (LLVMOrcJITTargetAddress)&TargetFn);

  if (LLVMErrorRef Err = LLVMOrcLLJITLookup(Jit, &Addr, "dependency")) {
    FAIL() << "dependency was not materialized " << toString(Err);
  }
  ASSERT_TRUE(!!Addr);
  ASSERT_EQ(Addr, (LLVMOrcJITTargetAddress)&TargetFn);
}
667 
668 struct SuspendedLookupContext {
669   std::function<void()> AsyncWork;
670   LLVMOrcSymbolStringPoolEntryRef NameToGenerate;
671   JITTargetAddress AddrToGenerate;
672 
673   bool Disposed = false;
674   bool QueryCompleted = true;
675 };
676 
TryToGenerateWithSuspendedLookup(LLVMOrcDefinitionGeneratorRef GeneratorObj,void * RawCtx,LLVMOrcLookupStateRef * LookupState,LLVMOrcLookupKind Kind,LLVMOrcJITDylibRef JD,LLVMOrcJITDylibLookupFlags JDLookupFlags,LLVMOrcCLookupSet LookupSet,size_t LookupSetSize)677 static LLVMErrorRef TryToGenerateWithSuspendedLookup(
678     LLVMOrcDefinitionGeneratorRef GeneratorObj, void *RawCtx,
679     LLVMOrcLookupStateRef *LookupState, LLVMOrcLookupKind Kind,
680     LLVMOrcJITDylibRef JD, LLVMOrcJITDylibLookupFlags JDLookupFlags,
681     LLVMOrcCLookupSet LookupSet, size_t LookupSetSize) {
682 
683   auto *Ctx = static_cast<SuspendedLookupContext *>(RawCtx);
684 
685   assert(LookupSetSize == 1);
686   assert(LookupSet[0].Name == Ctx->NameToGenerate);
687 
688   LLVMJITEvaluatedSymbol Sym = {0x1234, {LLVMJITSymbolGenericFlagsExported, 0}};
689   LLVMOrcRetainSymbolStringPoolEntry(LookupSet[0].Name);
690   LLVMOrcCSymbolMapPair Pair = {LookupSet[0].Name, Sym};
691   LLVMOrcCSymbolMapPair Pairs[] = {Pair};
692   LLVMOrcMaterializationUnitRef MU = LLVMOrcAbsoluteSymbols(Pairs, 1);
693 
694   // Capture and reset LookupState to suspend the lookup. We'll continue it in
695   // the SuspendedLookup testcase below.
696   Ctx->AsyncWork = [LS = *LookupState, JD, MU]() {
697     LLVMErrorRef Err = LLVMOrcJITDylibDefine(JD, MU);
698     LLVMOrcLookupStateContinueLookup(LS, Err);
699   };
700   *LookupState = nullptr;
701   return LLVMErrorSuccess;
702 }
703 
DisposeSuspendedLookupContext(void * Ctx)704 static void DisposeSuspendedLookupContext(void *Ctx) {
705   static_cast<SuspendedLookupContext *>(Ctx)->Disposed = true;
706 }
707 
708 static void
suspendLookupTestLookupHandlerCallback(LLVMErrorRef Err,LLVMOrcCSymbolMapPairs Result,size_t NumPairs,void * RawCtx)709 suspendLookupTestLookupHandlerCallback(LLVMErrorRef Err,
710                                        LLVMOrcCSymbolMapPairs Result,
711                                        size_t NumPairs, void *RawCtx) {
712   if (Err) {
713     FAIL() << "Suspended DefinitionGenerator did not create symbol \"foo\": "
714            << toString(Err);
715     return;
716   }
717 
718   EXPECT_EQ(NumPairs, 1U)
719       << "Unexpected number of result entries: expected 1, got " << NumPairs;
720 
721   auto *Ctx = static_cast<SuspendedLookupContext *>(RawCtx);
722   EXPECT_EQ(Result[0].Name, Ctx->NameToGenerate);
723   EXPECT_EQ(Result[0].Sym.Address, Ctx->AddrToGenerate);
724 
725   Ctx->QueryCompleted = true;
726 }
727 
TEST_F(OrcCAPITestBase, SuspendedLookup) {
  // Test that we can suspend lookup in a custom generator.
  SuspendedLookupContext Ctx;
  Ctx.NameToGenerate = LLVMOrcLLJITMangleAndIntern(Jit, "foo");
  Ctx.AddrToGenerate = 0x1234;

  // Add generator.
  LLVMOrcJITDylibAddGenerator(MainDylib,
                              LLVMOrcCreateCustomCAPIDefinitionGenerator(
                                  &TryToGenerateWithSuspendedLookup, &Ctx,
                                  DisposeSuspendedLookupContext));

  // Expect no work to do before the lookup.
  EXPECT_FALSE(Ctx.AsyncWork) << "Unexpected generator work before lookup";

  // Issue lookup. This should trigger the generator, but generation should
  // be suspended.
  LLVMOrcCJITDylibSearchOrderElement SO[] = {
      {MainDylib, LLVMOrcJITDylibLookupFlagsMatchExportedSymbolsOnly}};
  LLVMOrcRetainSymbolStringPoolEntry(Ctx.NameToGenerate);
  LLVMOrcCLookupSetElement LS[] = {
      {Ctx.NameToGenerate, LLVMOrcSymbolLookupFlagsRequiredSymbol}};
  LLVMOrcExecutionSessionLookup(ExecutionSession, LLVMOrcLookupKindStatic, SO,
                                1, LS, 1,
                                suspendLookupTestLookupHandlerCallback, &Ctx);

  // Expect that we now have generator work to do.
  EXPECT_TRUE(Ctx.AsyncWork)
      << "Failed to run generator (or failed to suspend generator)";

  // Do the work. This should allow the query to complete. Guarded because
  // invoking an empty std::function throws std::bad_function_call; we keep
  // going (rather than returning) so the JIT is still torn down below while
  // Ctx is alive.
  if (Ctx.AsyncWork)
    Ctx.AsyncWork();

  // Check that the query completed.
  EXPECT_TRUE(Ctx.QueryCompleted);

  // Release our local copy of the string.
  LLVMOrcReleaseSymbolStringPoolEntry(Ctx.NameToGenerate);

  // Explicitly tear down the JIT.
  LLVMOrcDisposeLLJIT(Jit);
  Jit = nullptr;

  // Check that the generator context was "destroyed".
  EXPECT_TRUE(Ctx.Disposed);
}
774