1 //===- Invoke.cpp ------------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
10 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
11 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
12 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
14 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
15 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/ExecutionEngine/CRunnerUtils.h"
18 #include "mlir/ExecutionEngine/ExecutionEngine.h"
19 #include "mlir/ExecutionEngine/MemRefUtils.h"
20 #include "mlir/ExecutionEngine/RunnerUtils.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/InitAllDialects.h"
23 #include "mlir/Parser.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
26 #include "mlir/Target/LLVMIR/Export.h"
27 #include "llvm/Support/TargetSelect.h"
28 #include "llvm/Support/raw_ostream.h"
29 
30 #include "gmock/gmock.h"
31 
32 using namespace mlir;
33 
34 static struct LLVMInitializer {
35   LLVMInitializer() {
36     llvm::InitializeNativeTarget();
37     llvm::InitializeNativeTargetAsmPrinter();
38   }
39 } initializer;
40 
41 /// Simple conversion pipeline for the purpose of testing sources written in
42 /// dialects lowering to LLVM Dialect.
43 static LogicalResult lowerToLLVMDialect(ModuleOp module) {
44   PassManager pm(module.getContext());
45   pm.addPass(mlir::createMemRefToLLVMPass());
46   pm.addNestedPass<FuncOp>(mlir::arith::createConvertArithmeticToLLVMPass());
47   pm.addPass(mlir::createLowerToLLVMPass());
48   pm.addPass(mlir::createReconcileUnrealizedCastsPass());
49   return pm.run(module);
50 }
51 
// The JIT is not supported on Windows at this time, so the JIT tests below are
// only compiled on other platforms.
53 #ifndef _WIN32
54 
55 TEST(MLIRExecutionEngine, AddInteger) {
56   std::string moduleStr = R"mlir(
57   func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
58     %res = arith.addi %arg0, %arg0 : i32
59     return %res : i32
60   }
61   )mlir";
62   DialectRegistry registry;
63   registerAllDialects(registry);
64   registerLLVMDialectTranslation(registry);
65   MLIRContext context(registry);
66   OwningModuleRef module = parseSourceString(moduleStr, &context);
67   ASSERT_TRUE(!!module);
68   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
69   auto jitOrError = ExecutionEngine::create(*module);
70   ASSERT_TRUE(!!jitOrError);
71   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
72   // The result of the function must be passed as output argument.
73   int result = 0;
74   llvm::Error error =
75       jit->invoke("foo", 42, ExecutionEngine::Result<int>(result));
76   ASSERT_TRUE(!error);
77   ASSERT_EQ(result, 42 + 42);
78 }
79 
80 TEST(MLIRExecutionEngine, SubtractFloat) {
81   std::string moduleStr = R"mlir(
82   func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
83     %res = arith.subf %arg0, %arg1 : f32
84     return %res : f32
85   }
86   )mlir";
87   DialectRegistry registry;
88   registerAllDialects(registry);
89   registerLLVMDialectTranslation(registry);
90   MLIRContext context(registry);
91   OwningModuleRef module = parseSourceString(moduleStr, &context);
92   ASSERT_TRUE(!!module);
93   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
94   auto jitOrError = ExecutionEngine::create(*module);
95   ASSERT_TRUE(!!jitOrError);
96   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
97   // The result of the function must be passed as output argument.
98   float result = -1;
99   llvm::Error error =
100       jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result));
101   ASSERT_TRUE(!error);
102   ASSERT_EQ(result, 42.f);
103 }
104 
105 TEST(NativeMemRefJit, ZeroRankMemref) {
106   OwningMemRef<float, 0> A({});
107   A[{}] = 42.;
108   ASSERT_EQ(*A->data, 42);
109   A[{}] = 0;
110   std::string moduleStr = R"mlir(
111   func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
112     %cst42 = arith.constant 42.0 : f32
113     memref.store %cst42, %arg0[] : memref<f32>
114     return
115   }
116   )mlir";
117   DialectRegistry registry;
118   registerAllDialects(registry);
119   registerLLVMDialectTranslation(registry);
120   MLIRContext context(registry);
121   auto module = parseSourceString(moduleStr, &context);
122   ASSERT_TRUE(!!module);
123   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
124   auto jitOrError = ExecutionEngine::create(*module);
125   ASSERT_TRUE(!!jitOrError);
126   auto jit = std::move(jitOrError.get());
127 
128   llvm::Error error = jit->invoke("zero_ranked", &*A);
129   ASSERT_TRUE(!error);
130   EXPECT_EQ((A[{}]), 42.);
131   for (float &elt : *A)
132     EXPECT_EQ(&elt, &(A[{}]));
133 }
134 
135 TEST(NativeMemRefJit, RankOneMemref) {
136   int64_t shape[] = {9};
137   OwningMemRef<float, 1> A(shape);
138   int count = 1;
139   for (float &elt : *A) {
140     EXPECT_EQ(&elt, &(A[{count - 1}]));
141     elt = count++;
142   }
143 
144   std::string moduleStr = R"mlir(
145   func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
146     %cst42 = arith.constant 42.0 : f32
147     %cst5 = arith.constant 5 : index
148     memref.store %cst42, %arg0[%cst5] : memref<?xf32>
149     return
150   }
151   )mlir";
152   DialectRegistry registry;
153   registerAllDialects(registry);
154   registerLLVMDialectTranslation(registry);
155   MLIRContext context(registry);
156   auto module = parseSourceString(moduleStr, &context);
157   ASSERT_TRUE(!!module);
158   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
159   auto jitOrError = ExecutionEngine::create(*module);
160   ASSERT_TRUE(!!jitOrError);
161   auto jit = std::move(jitOrError.get());
162 
163   llvm::Error error = jit->invoke("one_ranked", &*A);
164   ASSERT_TRUE(!error);
165   count = 1;
166   for (float &elt : *A) {
167     if (count == 6)
168       EXPECT_EQ(elt, 42.);
169     else
170       EXPECT_EQ(elt, count);
171     count++;
172   }
173 }
174 
175 TEST(NativeMemRefJit, BasicMemref) {
176   constexpr int K = 3;
177   constexpr int M = 7;
178   // Prepare arguments beforehand.
179   auto init = [=](float &elt, ArrayRef<int64_t> indices) {
180     assert(indices.size() == 2);
181     elt = M * indices[0] + indices[1];
182   };
183   int64_t shape[] = {K, M};
184   int64_t shapeAlloc[] = {K + 1, M + 1};
185   OwningMemRef<float, 2> A(shape, shapeAlloc, init);
186   ASSERT_EQ(A->sizes[0], K);
187   ASSERT_EQ(A->sizes[1], M);
188   ASSERT_EQ(A->strides[0], M + 1);
189   ASSERT_EQ(A->strides[1], 1);
190   for (int i = 0; i < K; ++i) {
191     for (int j = 0; j < M; ++j) {
192       EXPECT_EQ((A[{i, j}]), i * M + j);
193       EXPECT_EQ(&(A[{i, j}]), &((*A)[i][j]));
194     }
195   }
196   std::string moduleStr = R"mlir(
197   func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
198     %x = arith.constant 2 : index
199     %y = arith.constant 1 : index
200     %cst42 = arith.constant 42.0 : f32
201     memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
202     memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
203     return
204   }
205   )mlir";
206   DialectRegistry registry;
207   registerAllDialects(registry);
208   registerLLVMDialectTranslation(registry);
209   MLIRContext context(registry);
210   OwningModuleRef module = parseSourceString(moduleStr, &context);
211   ASSERT_TRUE(!!module);
212   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
213   auto jitOrError = ExecutionEngine::create(*module);
214   ASSERT_TRUE(!!jitOrError);
215   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
216 
217   llvm::Error error = jit->invoke("rank2_memref", &*A, &*A);
218   ASSERT_TRUE(!error);
219   EXPECT_EQ(((*A)[1][2]), 42.);
220   EXPECT_EQ((A[{2, 1}]), 42.);
221 }
222 
223 // A helper function that will be called from the JIT
224 static void memref_multiply(::StridedMemRefType<float, 2> *memref,
225                             int32_t coefficient) {
226   for (float &elt : *memref)
227     elt *= coefficient;
228 }
229 
230 TEST(NativeMemRefJit, JITCallback) {
231   constexpr int K = 2;
232   constexpr int M = 2;
233   int64_t shape[] = {K, M};
234   int64_t shapeAlloc[] = {K + 1, M + 1};
235   OwningMemRef<float, 2> A(shape, shapeAlloc);
236   int count = 1;
237   for (float &elt : *A)
238     elt = count++;
239 
240   std::string moduleStr = R"mlir(
241   func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32)  attributes { llvm.emit_c_interface }
242   func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
243     %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
244     call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
245     return
246   }
247   )mlir";
248   DialectRegistry registry;
249   registerAllDialects(registry);
250   registerLLVMDialectTranslation(registry);
251   MLIRContext context(registry);
252   auto module = parseSourceString(moduleStr, &context);
253   ASSERT_TRUE(!!module);
254   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
255   auto jitOrError = ExecutionEngine::create(*module);
256   ASSERT_TRUE(!!jitOrError);
257   auto jit = std::move(jitOrError.get());
258   // Define any extra symbols so they're available at runtime.
259   jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
260     llvm::orc::SymbolMap symbolMap;
261     symbolMap[interner("_mlir_ciface_callback")] =
262         llvm::JITEvaluatedSymbol::fromPointer(memref_multiply);
263     return symbolMap;
264   });
265 
266   int32_t coefficient = 3.;
267   llvm::Error error = jit->invoke("caller_for_callback", &*A, coefficient);
268   ASSERT_TRUE(!error);
269   count = 1;
270   for (float elt : *A)
271     ASSERT_EQ(elt, coefficient * count++);
272 }
273 
274 #endif // _WIN32
275