1 //===- Invoke.cpp ------------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"

#include "gmock/gmock.h"
32 
33 using namespace mlir;
34 
// The JIT isn't supported on Windows at this time.
36 #ifndef _WIN32
37 
38 static struct LLVMInitializer {
39   LLVMInitializer() {
40     llvm::InitializeNativeTarget();
41     llvm::InitializeNativeTargetAsmPrinter();
42   }
43 } initializer;
44 
45 /// Simple conversion pipeline for the purpose of testing sources written in
46 /// dialects lowering to LLVM Dialect.
47 static LogicalResult lowerToLLVMDialect(ModuleOp module) {
48   PassManager pm(module.getContext());
49   pm.addPass(mlir::createMemRefToLLVMPass());
50   pm.addNestedPass<func::FuncOp>(
51       mlir::arith::createConvertArithmeticToLLVMPass());
52   pm.addPass(mlir::createConvertFuncToLLVMPass());
53   pm.addPass(mlir::createReconcileUnrealizedCastsPass());
54   return pm.run(module);
55 }
56 
57 TEST(MLIRExecutionEngine, AddInteger) {
58   std::string moduleStr = R"mlir(
59   func.func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
60     %res = arith.addi %arg0, %arg0 : i32
61     return %res : i32
62   }
63   )mlir";
64   DialectRegistry registry;
65   registerAllDialects(registry);
66   registerLLVMDialectTranslation(registry);
67   MLIRContext context(registry);
68   OwningOpRef<ModuleOp> module =
69       parseSourceString<ModuleOp>(moduleStr, &context);
70   ASSERT_TRUE(!!module);
71   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
72   auto jitOrError = ExecutionEngine::create(*module);
73   ASSERT_TRUE(!!jitOrError);
74   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
75   // The result of the function must be passed as output argument.
76   int result = 0;
77   llvm::Error error =
78       jit->invoke("foo", 42, ExecutionEngine::Result<int>(result));
79   ASSERT_TRUE(!error);
80   ASSERT_EQ(result, 42 + 42);
81 }
82 
83 TEST(MLIRExecutionEngine, SubtractFloat) {
84   std::string moduleStr = R"mlir(
85   func.func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
86     %res = arith.subf %arg0, %arg1 : f32
87     return %res : f32
88   }
89   )mlir";
90   DialectRegistry registry;
91   registerAllDialects(registry);
92   registerLLVMDialectTranslation(registry);
93   MLIRContext context(registry);
94   OwningOpRef<ModuleOp> module =
95       parseSourceString<ModuleOp>(moduleStr, &context);
96   ASSERT_TRUE(!!module);
97   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
98   auto jitOrError = ExecutionEngine::create(*module);
99   ASSERT_TRUE(!!jitOrError);
100   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
101   // The result of the function must be passed as output argument.
102   float result = -1;
103   llvm::Error error =
104       jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result));
105   ASSERT_TRUE(!error);
106   ASSERT_EQ(result, 42.f);
107 }
108 
109 TEST(NativeMemRefJit, ZeroRankMemref) {
110   OwningMemRef<float, 0> a({});
111   a[{}] = 42.;
112   ASSERT_EQ(*a->data, 42);
113   a[{}] = 0;
114   std::string moduleStr = R"mlir(
115   func.func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
116     %cst42 = arith.constant 42.0 : f32
117     memref.store %cst42, %arg0[] : memref<f32>
118     return
119   }
120   )mlir";
121   DialectRegistry registry;
122   registerAllDialects(registry);
123   registerLLVMDialectTranslation(registry);
124   MLIRContext context(registry);
125   auto module = parseSourceString<ModuleOp>(moduleStr, &context);
126   ASSERT_TRUE(!!module);
127   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
128   auto jitOrError = ExecutionEngine::create(*module);
129   ASSERT_TRUE(!!jitOrError);
130   auto jit = std::move(jitOrError.get());
131 
132   llvm::Error error = jit->invoke("zero_ranked", &*a);
133   ASSERT_TRUE(!error);
134   EXPECT_EQ((a[{}]), 42.);
135   for (float &elt : *a)
136     EXPECT_EQ(&elt, &(a[{}]));
137 }
138 
139 TEST(NativeMemRefJit, RankOneMemref) {
140   int64_t shape[] = {9};
141   OwningMemRef<float, 1> a(shape);
142   int count = 1;
143   for (float &elt : *a) {
144     EXPECT_EQ(&elt, &(a[{count - 1}]));
145     elt = count++;
146   }
147 
148   std::string moduleStr = R"mlir(
149   func.func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
150     %cst42 = arith.constant 42.0 : f32
151     %cst5 = arith.constant 5 : index
152     memref.store %cst42, %arg0[%cst5] : memref<?xf32>
153     return
154   }
155   )mlir";
156   DialectRegistry registry;
157   registerAllDialects(registry);
158   registerLLVMDialectTranslation(registry);
159   MLIRContext context(registry);
160   auto module = parseSourceString<ModuleOp>(moduleStr, &context);
161   ASSERT_TRUE(!!module);
162   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
163   auto jitOrError = ExecutionEngine::create(*module);
164   ASSERT_TRUE(!!jitOrError);
165   auto jit = std::move(jitOrError.get());
166 
167   llvm::Error error = jit->invoke("one_ranked", &*a);
168   ASSERT_TRUE(!error);
169   count = 1;
170   for (float &elt : *a) {
171     if (count == 6)
172       EXPECT_EQ(elt, 42.);
173     else
174       EXPECT_EQ(elt, count);
175     count++;
176   }
177 }
178 
179 TEST(NativeMemRefJit, BasicMemref) {
180   constexpr int k = 3;
181   constexpr int m = 7;
182   // Prepare arguments beforehand.
183   auto init = [=](float &elt, ArrayRef<int64_t> indices) {
184     assert(indices.size() == 2);
185     elt = m * indices[0] + indices[1];
186   };
187   int64_t shape[] = {k, m};
188   int64_t shapeAlloc[] = {k + 1, m + 1};
189   OwningMemRef<float, 2> a(shape, shapeAlloc, init);
190   ASSERT_EQ(a->sizes[0], k);
191   ASSERT_EQ(a->sizes[1], m);
192   ASSERT_EQ(a->strides[0], m + 1);
193   ASSERT_EQ(a->strides[1], 1);
194   for (int i = 0; i < k; ++i) {
195     for (int j = 0; j < m; ++j) {
196       EXPECT_EQ((a[{i, j}]), i * m + j);
197       EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
198     }
199   }
200   std::string moduleStr = R"mlir(
201   func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
202     %x = arith.constant 2 : index
203     %y = arith.constant 1 : index
204     %cst42 = arith.constant 42.0 : f32
205     memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
206     memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
207     return
208   }
209   )mlir";
210   DialectRegistry registry;
211   registerAllDialects(registry);
212   registerLLVMDialectTranslation(registry);
213   MLIRContext context(registry);
214   OwningOpRef<ModuleOp> module =
215       parseSourceString<ModuleOp>(moduleStr, &context);
216   ASSERT_TRUE(!!module);
217   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
218   auto jitOrError = ExecutionEngine::create(*module);
219   ASSERT_TRUE(!!jitOrError);
220   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
221 
222   llvm::Error error = jit->invoke("rank2_memref", &*a, &*a);
223   ASSERT_TRUE(!error);
224   EXPECT_EQ(((*a)[1][2]), 42.);
225   EXPECT_EQ((a[{2, 1}]), 42.);
226 }
227 
228 // A helper function that will be called from the JIT
229 static void memrefMultiply(::StridedMemRefType<float, 2> *memref,
230                            int32_t coefficient) {
231   for (float &elt : *memref)
232     elt *= coefficient;
233 }
234 
235 // MSAN does not work with JIT.
236 #if __has_feature(memory_sanitizer)
237 #define MAYBE_JITCallback DISABLED_JITCallback
238 #else
239 #define MAYBE_JITCallback JITCallback
240 #endif
241 TEST(NativeMemRefJit, MAYBE_JITCallback) {
242   constexpr int k = 2;
243   constexpr int m = 2;
244   int64_t shape[] = {k, m};
245   int64_t shapeAlloc[] = {k + 1, m + 1};
246   OwningMemRef<float, 2> a(shape, shapeAlloc);
247   int count = 1;
248   for (float &elt : *a)
249     elt = count++;
250 
251   std::string moduleStr = R"mlir(
252   func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32)  attributes { llvm.emit_c_interface }
253   func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
254     %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
255     call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
256     return
257   }
258   )mlir";
259   DialectRegistry registry;
260   registerAllDialects(registry);
261   registerLLVMDialectTranslation(registry);
262   MLIRContext context(registry);
263   auto module = parseSourceString<ModuleOp>(moduleStr, &context);
264   ASSERT_TRUE(!!module);
265   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
266   auto jitOrError = ExecutionEngine::create(*module);
267   ASSERT_TRUE(!!jitOrError);
268   auto jit = std::move(jitOrError.get());
269   // Define any extra symbols so they're available at runtime.
270   jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
271     llvm::orc::SymbolMap symbolMap;
272     symbolMap[interner("_mlir_ciface_callback")] =
273         llvm::JITEvaluatedSymbol::fromPointer(memrefMultiply);
274     return symbolMap;
275   });
276 
277   int32_t coefficient = 3.;
278   llvm::Error error = jit->invoke("caller_for_callback", &*a, coefficient);
279   ASSERT_TRUE(!error);
280   count = 1;
281   for (float elt : *a)
282     ASSERT_EQ(elt, coefficient * count++);
283 }
284 
285 #endif // _WIN32
286