1 //===- execution_engine.c - Test for the C bindings for the MLIR JIT-------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 /* RUN: mlir-capi-execution-engine-test 2>&1 | FileCheck %s
11  */
12 /* REQUIRES: native
13  */
14 
15 #include "mlir-c/Conversion.h"
16 #include "mlir-c/ExecutionEngine.h"
17 #include "mlir-c/IR.h"
18 #include "mlir-c/Registration.h"
19 
20 #include <assert.h>
21 #include <math.h>
22 #include <stdio.h>
23 #include <stdlib.h>
24 #include <string.h>
25 
26 void lowerModuleToLLVM(MlirContext ctx, MlirModule module) {
27   MlirPassManager pm = mlirPassManagerCreate(ctx);
28   MlirOpPassManager opm = mlirPassManagerGetNestedUnder(
29       pm, mlirStringRefCreateFromCString("func.func"));
30   mlirPassManagerAddOwnedPass(pm, mlirCreateConversionConvertFuncToLLVM());
31   mlirOpPassManagerAddOwnedPass(opm,
32                                 mlirCreateConversionConvertArithmeticToLLVM());
33   MlirLogicalResult status = mlirPassManagerRun(pm, module);
34   if (mlirLogicalResultIsFailure(status)) {
35     fprintf(stderr, "Unexpected failure running pass pipeline\n");
36     exit(2);
37   }
38   mlirPassManagerDestroy(pm);
39 }
40 
41 // CHECK-LABEL: Running test 'testSimpleExecution'
42 void testSimpleExecution() {
43   MlirContext ctx = mlirContextCreate();
44   mlirRegisterAllDialects(ctx);
45   MlirModule module = mlirModuleCreateParse(
46       ctx, mlirStringRefCreateFromCString(
47                // clang-format off
48 "module {                                                                    \n"
49 "  func.func @add(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {     \n"
50 "    %res = arith.addi %arg0, %arg0 : i32                                        \n"
51 "    return %res : i32                                                           \n"
52 "  }                                                                             \n"
53 "}"));
54   // clang-format on
55   lowerModuleToLLVM(ctx, module);
56   mlirRegisterAllLLVMTranslations(ctx);
57   MlirExecutionEngine jit = mlirExecutionEngineCreate(
58       module, /*optLevel=*/2, /*numPaths=*/0, /*sharedLibPaths=*/NULL);
59   if (mlirExecutionEngineIsNull(jit)) {
60     fprintf(stderr, "Execution engine creation failed");
61     exit(2);
62   }
63   int input = 42;
64   int result = -1;
65   void *args[2] = {&input, &result};
66   if (mlirLogicalResultIsFailure(mlirExecutionEngineInvokePacked(
67           jit, mlirStringRefCreateFromCString("add"), args))) {
68     fprintf(stderr, "Execution engine creation failed");
69     abort();
70   }
71   // CHECK: Input: 42 Result: 84
72   printf("Input: %d Result: %d\n", input, result);
73   mlirExecutionEngineDestroy(jit);
74   mlirModuleDestroy(module);
75   mlirContextDestroy(ctx);
76 }
77 
int main(void) {

// Two-level expansion so the argument is macro-expanded before being turned
// into a string literal. Names avoid the reserved `_[A-Z]` prefix.
#define STRINGIFY_IMPL(x) #x
#define STRINGIFY(x) STRINGIFY_IMPL(x)
// Prints the "Running test '<name>'" label line, then calls the test. Wrapped
// in do/while(0) so the macro behaves as a single statement.
#define TEST(test)                                                             \
  do {                                                                         \
    printf("Running test '" STRINGIFY(test) "'\n");                            \
    test();                                                                    \
  } while (0)

  TEST(testSimpleExecution);
  return 0;
}
89