1 //===- Invoke.cpp ------------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
10 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
11 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
12 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
13 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
14 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
15 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/Dialect/Linalg/Passes.h"
18 #include "mlir/ExecutionEngine/CRunnerUtils.h"
19 #include "mlir/ExecutionEngine/ExecutionEngine.h"
20 #include "mlir/ExecutionEngine/MemRefUtils.h"
21 #include "mlir/ExecutionEngine/RunnerUtils.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/InitAllDialects.h"
24 #include "mlir/Parser/Parser.h"
25 #include "mlir/Pass/PassManager.h"
26 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
27 #include "mlir/Target/LLVMIR/Export.h"
28 #include "llvm/Support/TargetSelect.h"
29 #include "llvm/Support/raw_ostream.h"
30 
31 #include "gmock/gmock.h"
32 
// There is no JIT support on SPARC yet. Tests wrapped in SKIP_WITHOUT_JIT are
// therefore registered with GoogleTest's DISABLED_ prefix on SPARC and under
// their plain name on every other target.
#if defined(__sparc__)
#define SKIP_WITHOUT_JIT(x) DISABLED_##x
#else
#define SKIP_WITHOUT_JIT(x) x
#endif
39 
40 using namespace mlir;
41 
// The JIT isn't supported on Windows at this time.
43 #ifndef _WIN32
44 
45 static struct LLVMInitializer {
LLVMInitializerLLVMInitializer46   LLVMInitializer() {
47     llvm::InitializeNativeTarget();
48     llvm::InitializeNativeTargetAsmPrinter();
49   }
50 } initializer;
51 
52 /// Simple conversion pipeline for the purpose of testing sources written in
53 /// dialects lowering to LLVM Dialect.
lowerToLLVMDialect(ModuleOp module)54 static LogicalResult lowerToLLVMDialect(ModuleOp module) {
55   PassManager pm(module.getContext());
56   pm.addPass(mlir::createMemRefToLLVMPass());
57   pm.addNestedPass<func::FuncOp>(
58       mlir::arith::createConvertArithmeticToLLVMPass());
59   pm.addPass(mlir::createConvertFuncToLLVMPass());
60   pm.addPass(mlir::createReconcileUnrealizedCastsPass());
61   return pm.run(module);
62 }
63 
/// JIT-compiles an i32 function that adds its argument to itself. It then
/// checks that invoking it with 42 through ExecutionEngine::invoke writes
/// 42 + 42 into the result out-parameter.
TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(AddInteger)) {
  std::string moduleStr = R"mlir(
  func.func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
    %res = arith.addi %arg0, %arg0 : i32
    return %res : i32
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  OwningOpRef<ModuleOp> module =
      parseSourceString<ModuleOp>(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // On failure, consume the error while reporting it. An llvm::Expected or
  // llvm::Error destroyed unchecked aborts the process when LLVM's
  // ABI-breaking checks are enabled, which would hide the real diagnostic.
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
  // The result of the function must be passed as output argument.
  int result = 0;
  llvm::Error error =
      jit->invoke("foo", 42, ExecutionEngine::Result<int>(result));
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  ASSERT_EQ(result, 42 + 42);
}
89 
/// JIT-compiles a two-argument f32 subtraction and checks that invoking it
/// with (43, 1) writes 42 into the result out-parameter.
TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(SubtractFloat)) {
  std::string moduleStr = R"mlir(
  func.func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
    %res = arith.subf %arg0, %arg1 : f32
    return %res : f32
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  OwningOpRef<ModuleOp> module =
      parseSourceString<ModuleOp>(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // Consume the error on failure so it is reported instead of aborting as an
  // unchecked llvm::Expected.
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
  // The result of the function must be passed as output argument.
  float result = -1;
  llvm::Error error =
      jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result));
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  // 43 - 1 is exactly representable, so exact float equality is intended.
  ASSERT_EQ(result, 42.f);
}
115 
/// Checks the rank-0 memref path end to end:
///   - Host-side OwningMemRef indexing with an empty index list is checked.
///   - JIT code storing 42 into a `memref<f32>` is checked to be visible
///     through the host descriptor.
///   - Iterating the rank-0 memref is checked to yield only its single
///     element.
TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(ZeroRankMemref)) {
  OwningMemRef<float, 0> a({});
  a[{}] = 42.;
  ASSERT_EQ(*a->data, 42);
  a[{}] = 0;
  std::string moduleStr = R"mlir(
  func.func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
    %cst42 = arith.constant 42.0 : f32
    memref.store %cst42, %arg0[] : memref<f32>
    return
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  OwningOpRef<ModuleOp> module =
      parseSourceString<ModuleOp>(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // Consume the error on failure so it is reported instead of aborting as an
  // unchecked llvm::Expected.
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  auto jit = std::move(jitOrError.get());

  llvm::Error error = jit->invoke("zero_ranked", &*a);
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  EXPECT_EQ((a[{}]), 42.);
  for (float &elt : *a)
    EXPECT_EQ(&elt, &(a[{}]));
}
145 
/// Fills a 9-element rank-1 memref with 1..9 and checks that iteration order
/// matches indexing. JIT code then stores 42 at index 5, and the test checks
/// that only that element changed.
TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(RankOneMemref)) {
  int64_t shape[] = {9};
  OwningMemRef<float, 1> a(shape);
  int count = 1;
  for (float &elt : *a) {
    EXPECT_EQ(&elt, &(a[{count - 1}]));
    elt = count++;
  }

  std::string moduleStr = R"mlir(
  func.func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
    %cst42 = arith.constant 42.0 : f32
    %cst5 = arith.constant 5 : index
    memref.store %cst42, %arg0[%cst5] : memref<?xf32>
    return
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  OwningOpRef<ModuleOp> module =
      parseSourceString<ModuleOp>(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // Consume the error on failure so it is reported instead of aborting as an
  // unchecked llvm::Expected.
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  auto jit = std::move(jitOrError.get());

  llvm::Error error = jit->invoke("one_ranked", &*a);
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  count = 1;
  for (float &elt : *a) {
    // Element 6 in 1-based counting is index 5, the one written by the JIT.
    if (count == 6)
      EXPECT_EQ(elt, 42.);
    else
      EXPECT_EQ(elt, count);
    count++;
  }
}
185 
/// Builds a 3x7 float memref inside a 4x8 allocation, so the row stride is 8,
/// and checks host-side sizes, strides and indexing. It then passes the same
/// memref as both arguments of `@rank2_memref` and checks that the two JIT
/// stores land at [1][2] and [2][1].
TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) {
  constexpr int k = 3;
  constexpr int m = 7;
  // Prepare arguments beforehand.
  auto init = [=](float &elt, ArrayRef<int64_t> indices) {
    assert(indices.size() == 2);
    elt = m * indices[0] + indices[1];
  };
  int64_t shape[] = {k, m};
  int64_t shapeAlloc[] = {k + 1, m + 1};
  OwningMemRef<float, 2> a(shape, shapeAlloc, init);
  ASSERT_EQ(a->sizes[0], k);
  ASSERT_EQ(a->sizes[1], m);
  ASSERT_EQ(a->strides[0], m + 1);
  ASSERT_EQ(a->strides[1], 1);
  for (int i = 0; i < k; ++i) {
    for (int j = 0; j < m; ++j) {
      EXPECT_EQ((a[{i, j}]), i * m + j);
      EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
    }
  }
  std::string moduleStr = R"mlir(
  func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
    %x = arith.constant 2 : index
    %y = arith.constant 1 : index
    %cst42 = arith.constant 42.0 : f32
    memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
    memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
    return
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  OwningOpRef<ModuleOp> module =
      parseSourceString<ModuleOp>(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // Consume the error on failure so it is reported instead of aborting as an
  // unchecked llvm::Expected.
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());

  // Both arguments alias the same buffer, so both stores are observed in `a`.
  llvm::Error error = jit->invoke("rank2_memref", &*a, &*a);
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  EXPECT_EQ(((*a)[1][2]), 42.);
  EXPECT_EQ((a[{2, 1}]), 42.);
}
234 
235 // A helper function that will be called from the JIT
memrefMultiply(::StridedMemRefType<float,2> * memref,int32_t coefficient)236 static void memrefMultiply(::StridedMemRefType<float, 2> *memref,
237                            int32_t coefficient) {
238   for (float &elt : *memref)
239     elt *= coefficient;
240 }
241 
242 // MSAN does not work with JIT.
243 #if __has_feature(memory_sanitizer)
244 #define MAYBE_JITCallback DISABLED_JITCallback
245 #else
246 #define MAYBE_JITCallback SKIP_WITHOUT_JIT(JITCallback)
247 #endif
/// Checks that JIT-compiled code can call back into the host.
/// - The external `@callback` declaration is resolved to memrefMultiply
///   through the `_mlir_ciface_callback` symbol registered below.
/// - Every element of the 2x2 memref (initialized to 1..4) is then expected
///   to be multiplied by the coefficient.
TEST(NativeMemRefJit, MAYBE_JITCallback) {
  constexpr int k = 2;
  constexpr int m = 2;
  int64_t shape[] = {k, m};
  int64_t shapeAlloc[] = {k + 1, m + 1};
  OwningMemRef<float, 2> a(shape, shapeAlloc);
  int count = 1;
  for (float &elt : *a)
    elt = count++;

  std::string moduleStr = R"mlir(
  func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32)  attributes { llvm.emit_c_interface }
  func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
    %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
    call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
    return
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  OwningOpRef<ModuleOp> module =
      parseSourceString<ModuleOp>(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // Consume the error on failure so it is reported instead of aborting as an
  // unchecked llvm::Expected.
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  auto jit = std::move(jitOrError.get());
  // Define any extra symbols so they're available at runtime. The lambda
  // needs no captures: memrefMultiply is a namespace-scope function.
  jit->registerSymbols([](llvm::orc::MangleAndInterner interner) {
    llvm::orc::SymbolMap symbolMap;
    symbolMap[interner("_mlir_ciface_callback")] =
        llvm::JITEvaluatedSymbol::fromPointer(memrefMultiply);
    return symbolMap;
  });

  // Integer literal: the value is passed as an MLIR i32.
  int32_t coefficient = 3;
  llvm::Error error = jit->invoke("caller_for_callback", &*a, coefficient);
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  count = 1;
  for (float elt : *a)
    ASSERT_EQ(elt, coefficient * count++);
}
291 
292 #endif // _WIN32
293