1 //===- execution_engine.c - Test for the C bindings for the MLIR JIT-------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 /* RUN: mlir-capi-execution-engine-test 2>&1 | FileCheck %s
11  */
12 /* REQUIRES: host-supports-jit
13  */
14 
15 #include "mlir-c/Conversion.h"
16 #include "mlir-c/ExecutionEngine.h"
17 #include "mlir-c/IR.h"
18 #include "mlir-c/RegisterEverything.h"
19 
20 #include <assert.h>
21 #include <math.h>
22 #include <stdio.h>
23 #include <stdlib.h>
24 #include <string.h>
25 
registerAllUpstreamDialects(MlirContext ctx)26 static void registerAllUpstreamDialects(MlirContext ctx) {
27   MlirDialectRegistry registry = mlirDialectRegistryCreate();
28   mlirRegisterAllDialects(registry);
29   mlirContextAppendDialectRegistry(ctx, registry);
30   mlirDialectRegistryDestroy(registry);
31 }
32 
lowerModuleToLLVM(MlirContext ctx,MlirModule module)33 void lowerModuleToLLVM(MlirContext ctx, MlirModule module) {
34   MlirPassManager pm = mlirPassManagerCreate(ctx);
35   MlirOpPassManager opm = mlirPassManagerGetNestedUnder(
36       pm, mlirStringRefCreateFromCString("func.func"));
37   mlirPassManagerAddOwnedPass(pm, mlirCreateConversionConvertFuncToLLVM());
38   mlirOpPassManagerAddOwnedPass(opm,
39                                 mlirCreateConversionConvertArithmeticToLLVM());
40   MlirLogicalResult status = mlirPassManagerRun(pm, module);
41   if (mlirLogicalResultIsFailure(status)) {
42     fprintf(stderr, "Unexpected failure running pass pipeline\n");
43     exit(2);
44   }
45   mlirPassManagerDestroy(pm);
46 }
47 
48 // CHECK-LABEL: Running test 'testSimpleExecution'
testSimpleExecution()49 void testSimpleExecution() {
50   MlirContext ctx = mlirContextCreate();
51   registerAllUpstreamDialects(ctx);
52 
53   MlirModule module = mlirModuleCreateParse(
54       ctx, mlirStringRefCreateFromCString(
55                // clang-format off
56 "module {                                                                    \n"
57 "  func.func @add(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {     \n"
58 "    %res = arith.addi %arg0, %arg0 : i32                                        \n"
59 "    return %res : i32                                                           \n"
60 "  }                                                                             \n"
61 "}"));
62   // clang-format on
63   lowerModuleToLLVM(ctx, module);
64   mlirRegisterAllLLVMTranslations(ctx);
65   MlirExecutionEngine jit = mlirExecutionEngineCreate(
66       module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL);
67   if (mlirExecutionEngineIsNull(jit)) {
68     fprintf(stderr, "Execution engine creation failed");
69     exit(2);
70   }
71   int input = 42;
72   int result = -1;
73   void *args[2] = {&input, &result};
74   if (mlirLogicalResultIsFailure(mlirExecutionEngineInvokePacked(
75           jit, mlirStringRefCreateFromCString("add"), args))) {
76     fprintf(stderr, "Execution engine creation failed");
77     abort();
78   }
79   // CHECK: Input: 42 Result: 84
80   printf("Input: %d Result: %d\n", input, result);
81   mlirExecutionEngineDestroy(jit);
82   mlirModuleDestroy(module);
83   mlirContextDestroy(ctx);
84 }
85 
main()86 int main() {
87 
88 #define _STRINGIFY(x) #x
89 #define STRINGIFY(x) _STRINGIFY(x)
90 #define TEST(test)                                                             \
91   printf("Running test '" STRINGIFY(test) "'\n");                              \
92   test();
93 
94   TEST(testSimpleExecution);
95   return 0;
96 }
97