1 //===- Invoke.cpp ------------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"

#include "gmock/gmock.h"
32
// Wraps the name of a test that needs the JIT. SPARC currently lacks JIT
// support, so there the name gets gtest's DISABLED_ prefix; on every other
// host the name is passed through unchanged.
#if defined(__sparc__)
#define SKIP_WITHOUT_JIT(x) DISABLED_##x
#else
#define SKIP_WITHOUT_JIT(x) x
#endif
39
40 using namespace mlir;
41
// The JIT isn't supported on Windows at this time.
43 #ifndef _WIN32
44
45 static struct LLVMInitializer {
LLVMInitializerLLVMInitializer46 LLVMInitializer() {
47 llvm::InitializeNativeTarget();
48 llvm::InitializeNativeTargetAsmPrinter();
49 }
50 } initializer;
51
52 /// Simple conversion pipeline for the purpose of testing sources written in
53 /// dialects lowering to LLVM Dialect.
lowerToLLVMDialect(ModuleOp module)54 static LogicalResult lowerToLLVMDialect(ModuleOp module) {
55 PassManager pm(module.getContext());
56 pm.addPass(mlir::createMemRefToLLVMPass());
57 pm.addNestedPass<func::FuncOp>(
58 mlir::arith::createConvertArithmeticToLLVMPass());
59 pm.addPass(mlir::createConvertFuncToLLVMPass());
60 pm.addPass(mlir::createReconcileUnrealizedCastsPass());
61 return pm.run(module);
62 }
63
/// JIT-compiles a function that adds an i32 to itself and checks the result,
/// which the C-interface wrapper hands back through an output argument.
TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(AddInteger)) {
  std::string moduleStr = R"mlir(
  func.func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
    %res = arith.addi %arg0, %arg0 : i32
    return %res : i32
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  OwningOpRef<ModuleOp> module =
      parseSourceString<ModuleOp>(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // On failure, report and consume the error: an llvm::Expected/llvm::Error
  // destroyed while holding an unchecked error aborts the test binary in
  // builds with LLVM ABI-breaking checks enabled.
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
  // The result of the function must be passed as output argument.
  int result = 0;
  llvm::Error error =
      jit->invoke("foo", 42, ExecutionEngine::result(result));
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  ASSERT_EQ(result, 42 + 42);
}
89
/// JIT-compiles a function that subtracts two f32 values and checks the
/// result, which the C-interface wrapper hands back through an output argument.
TEST(MLIRExecutionEngine, SKIP_WITHOUT_JIT(SubtractFloat)) {
  std::string moduleStr = R"mlir(
  func.func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
    %res = arith.subf %arg0, %arg1 : f32
    return %res : f32
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  OwningOpRef<ModuleOp> module =
      parseSourceString<ModuleOp>(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // On failure, report and consume the error: an llvm::Expected/llvm::Error
  // destroyed while holding an unchecked error aborts the test binary in
  // builds with LLVM ABI-breaking checks enabled.
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
  // The result of the function must be passed as output argument.
  float result = -1;
  llvm::Error error =
      jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result));
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  // 43 - 1 is exactly representable, so exact comparison is safe.
  ASSERT_EQ(result, 42.f);
}
115
/// Passes a rank-0 memref to JIT-compiled code that stores 42 into it, then
/// checks that the host-side buffer observes the write.
TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(ZeroRankMemref)) {
  OwningMemRef<float, 0> a({});
  a[{}] = 42.;
  ASSERT_EQ(*a->data, 42);
  a[{}] = 0;
  std::string moduleStr = R"mlir(
  func.func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
    %cst42 = arith.constant 42.0 : f32
    memref.store %cst42, %arg0[] : memref<f32>
    return
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  auto module = parseSourceString<ModuleOp>(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // On failure, report and consume the error: an llvm::Expected/llvm::Error
  // destroyed while holding an unchecked error aborts the test binary in
  // builds with LLVM ABI-breaking checks enabled.
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  auto jit = std::move(jitOrError.get());

  llvm::Error error = jit->invoke("zero_ranked", &*a);
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  EXPECT_EQ((a[{}]), 42.);
  // Every element yielded by iteration must alias the single stored value.
  for (float &elt : *a)
    EXPECT_EQ(&elt, &(a[{}]));
}
145
/// Fills a 9-element rank-1 memref with 1..9, lets JIT-compiled code store 42
/// at index 5, and checks that only that element changed.
TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(RankOneMemref)) {
  int64_t shape[] = {9};
  OwningMemRef<float, 1> a(shape);
  int count = 1;
  for (float &elt : *a) {
    // Iteration order must agree with explicit indexing.
    EXPECT_EQ(&elt, &(a[{count - 1}]));
    elt = count++;
  }

  std::string moduleStr = R"mlir(
  func.func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
    %cst42 = arith.constant 42.0 : f32
    %cst5 = arith.constant 5 : index
    memref.store %cst42, %arg0[%cst5] : memref<?xf32>
    return
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  auto module = parseSourceString<ModuleOp>(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // On failure, report and consume the error: an llvm::Expected/llvm::Error
  // destroyed while holding an unchecked error aborts the test binary in
  // builds with LLVM ABI-breaking checks enabled.
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  auto jit = std::move(jitOrError.get());

  llvm::Error error = jit->invoke("one_ranked", &*a);
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  count = 1;
  for (float &elt : *a) {
    // `count` is 1-based, so index 5 is visited when count == 6.
    if (count == 6)
      EXPECT_EQ(elt, 42.);
    else
      EXPECT_EQ(elt, count);
    count++;
  }
}
185
/// Builds a 3x7 rank-2 memref inside a padded 4x8 allocation and checks its
/// layout and indexing. It then passes the memref as both arguments to
/// JIT-compiled code that stores 42 at [1, 2] through the first argument and
/// at [2, 1] through the second.
TEST(NativeMemRefJit, SKIP_WITHOUT_JIT(BasicMemref)) {
  constexpr int k = 3;
  constexpr int m = 7;
  // Prepare arguments beforehand.
  auto init = [=](float &elt, ArrayRef<int64_t> indices) {
    assert(indices.size() == 2);
    elt = m * indices[0] + indices[1];
  };
  int64_t shape[] = {k, m};
  int64_t shapeAlloc[] = {k + 1, m + 1};
  OwningMemRef<float, 2> a(shape, shapeAlloc, init);
  ASSERT_EQ(a->sizes[0], k);
  ASSERT_EQ(a->sizes[1], m);
  // The row stride follows the padded allocation, not the logical shape.
  ASSERT_EQ(a->strides[0], m + 1);
  ASSERT_EQ(a->strides[1], 1);
  for (int i = 0; i < k; ++i) {
    for (int j = 0; j < m; ++j) {
      EXPECT_EQ((a[{i, j}]), i * m + j);
      EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
    }
  }
  std::string moduleStr = R"mlir(
  func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
    %x = arith.constant 2 : index
    %y = arith.constant 1 : index
    %cst42 = arith.constant 42.0 : f32
    memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
    memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
    return
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  OwningOpRef<ModuleOp> module =
      parseSourceString<ModuleOp>(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // On failure, report and consume the error: an llvm::Expected/llvm::Error
  // destroyed while holding an unchecked error aborts the test binary in
  // builds with LLVM ABI-breaking checks enabled.
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());

  // Both arguments alias the same buffer.
  llvm::Error error = jit->invoke("rank2_memref", &*a, &*a);
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  EXPECT_EQ(((*a)[1][2]), 42.);
  EXPECT_EQ((a[{2, 1}]), 42.);
}
234
235 // A helper function that will be called from the JIT
memrefMultiply(::StridedMemRefType<float,2> * memref,int32_t coefficient)236 static void memrefMultiply(::StridedMemRefType<float, 2> *memref,
237 int32_t coefficient) {
238 for (float &elt : *memref)
239 elt *= coefficient;
240 }
241
242 // MSAN does not work with JIT.
243 #if __has_feature(memory_sanitizer)
244 #define MAYBE_JITCallback DISABLED_JITCallback
245 #else
246 #define MAYBE_JITCallback SKIP_WITHOUT_JIT(JITCallback)
247 #endif
/// Checks that JIT-compiled code can call back into a host function. The
/// module's external @callback is bound to memrefMultiply, which scales the
/// 2x2 memref passed to @caller_for_callback by `coefficient`.
TEST(NativeMemRefJit, MAYBE_JITCallback) {
  constexpr int k = 2;
  constexpr int m = 2;
  int64_t shape[] = {k, m};
  int64_t shapeAlloc[] = {k + 1, m + 1};
  OwningMemRef<float, 2> a(shape, shapeAlloc);
  int count = 1;
  for (float &elt : *a)
    elt = count++;

  std::string moduleStr = R"mlir(
  func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface }
  func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
    %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
    call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
    return
  }
  )mlir";
  DialectRegistry registry;
  registerAllDialects(registry);
  registerLLVMDialectTranslation(registry);
  MLIRContext context(registry);
  auto module = parseSourceString<ModuleOp>(moduleStr, &context);
  ASSERT_TRUE(!!module);
  ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
  auto jitOrError = ExecutionEngine::create(*module);
  // On failure, report and consume the error: an llvm::Expected/llvm::Error
  // destroyed while holding an unchecked error aborts the test binary in
  // builds with LLVM ABI-breaking checks enabled.
  ASSERT_TRUE(!!jitOrError) << llvm::toString(jitOrError.takeError());
  auto jit = std::move(jitOrError.get());
  // Define any extra symbols so they're available at runtime: bind
  // `_mlir_ciface_callback` to the host helper so the JIT-compiled call to
  // @callback resolves to it. The lambda needs no captures.
  jit->registerSymbols([](llvm::orc::MangleAndInterner interner) {
    llvm::orc::SymbolMap symbolMap;
    symbolMap[interner("_mlir_ciface_callback")] =
        llvm::JITEvaluatedSymbol::fromPointer(memrefMultiply);
    return symbolMap;
  });

  // Integer literal: the value feeds an i32 parameter.
  int32_t coefficient = 3;
  llvm::Error error = jit->invoke("caller_for_callback", &*a, coefficient);
  ASSERT_TRUE(!error) << llvm::toString(std::move(error));
  count = 1;
  for (float elt : *a)
    ASSERT_EQ(elt, coefficient * count++);
}
291
292 #endif // _WIN32
293