1 //===- Invoke.cpp ------------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
28 
29 #include "gmock/gmock.h"
30 
31 using namespace mlir;
32 
33 static struct LLVMInitializer {
34   LLVMInitializer() {
35     llvm::InitializeNativeTarget();
36     llvm::InitializeNativeTargetAsmPrinter();
37   }
38 } initializer;
39 
40 /// Simple conversion pipeline for the purpose of testing sources written in
41 /// dialects lowering to LLVM Dialect.
42 static LogicalResult lowerToLLVMDialect(ModuleOp module) {
43   PassManager pm(module.getContext());
44   pm.addPass(mlir::createMemRefToLLVMPass());
45   pm.addPass(mlir::createLowerToLLVMPass());
46   pm.addPass(mlir::createReconcileUnrealizedCastsPass());
47   return pm.run(module);
48 }
49 
// The JIT isn't supported on Windows at this time.
51 #ifndef _WIN32
52 
53 TEST(MLIRExecutionEngine, AddInteger) {
54   std::string moduleStr = R"mlir(
55   func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
56     %res = std.addi %arg0, %arg0 : i32
57     return %res : i32
58   }
59   )mlir";
60   DialectRegistry registry;
61   registerAllDialects(registry);
62   registerLLVMDialectTranslation(registry);
63   MLIRContext context(registry);
64   OwningModuleRef module = parseSourceString(moduleStr, &context);
65   ASSERT_TRUE(!!module);
66   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
67   auto jitOrError = ExecutionEngine::create(*module);
68   ASSERT_TRUE(!!jitOrError);
69   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
70   // The result of the function must be passed as output argument.
71   int result = 0;
72   llvm::Error error =
73       jit->invoke("foo", 42, ExecutionEngine::Result<int>(result));
74   ASSERT_TRUE(!error);
75   ASSERT_EQ(result, 42 + 42);
76 }
77 
78 TEST(MLIRExecutionEngine, SubtractFloat) {
79   std::string moduleStr = R"mlir(
80   func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
81     %res = std.subf %arg0, %arg1 : f32
82     return %res : f32
83   }
84   )mlir";
85   DialectRegistry registry;
86   registerAllDialects(registry);
87   registerLLVMDialectTranslation(registry);
88   MLIRContext context(registry);
89   OwningModuleRef module = parseSourceString(moduleStr, &context);
90   ASSERT_TRUE(!!module);
91   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
92   auto jitOrError = ExecutionEngine::create(*module);
93   ASSERT_TRUE(!!jitOrError);
94   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
95   // The result of the function must be passed as output argument.
96   float result = -1;
97   llvm::Error error =
98       jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result));
99   ASSERT_TRUE(!error);
100   ASSERT_EQ(result, 42.f);
101 }
102 
103 TEST(NativeMemRefJit, ZeroRankMemref) {
104   OwningMemRef<float, 0> A({});
105   A[{}] = 42.;
106   ASSERT_EQ(*A->data, 42);
107   A[{}] = 0;
108   std::string moduleStr = R"mlir(
109   func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
110     %cst42 = constant 42.0 : f32
111     memref.store %cst42, %arg0[] : memref<f32>
112     return
113   }
114   )mlir";
115   DialectRegistry registry;
116   registerAllDialects(registry);
117   registerLLVMDialectTranslation(registry);
118   MLIRContext context(registry);
119   auto module = parseSourceString(moduleStr, &context);
120   ASSERT_TRUE(!!module);
121   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
122   auto jitOrError = ExecutionEngine::create(*module);
123   ASSERT_TRUE(!!jitOrError);
124   auto jit = std::move(jitOrError.get());
125 
126   llvm::Error error = jit->invoke("zero_ranked", &*A);
127   ASSERT_TRUE(!error);
128   EXPECT_EQ((A[{}]), 42.);
129   for (float &elt : *A)
130     EXPECT_EQ(&elt, &(A[{}]));
131 }
132 
133 TEST(NativeMemRefJit, RankOneMemref) {
134   int64_t shape[] = {9};
135   OwningMemRef<float, 1> A(shape);
136   int count = 1;
137   for (float &elt : *A) {
138     EXPECT_EQ(&elt, &(A[{count - 1}]));
139     elt = count++;
140   }
141 
142   std::string moduleStr = R"mlir(
143   func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
144     %cst42 = constant 42.0 : f32
145     %cst5 = constant 5 : index
146     memref.store %cst42, %arg0[%cst5] : memref<?xf32>
147     return
148   }
149   )mlir";
150   DialectRegistry registry;
151   registerAllDialects(registry);
152   registerLLVMDialectTranslation(registry);
153   MLIRContext context(registry);
154   auto module = parseSourceString(moduleStr, &context);
155   ASSERT_TRUE(!!module);
156   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
157   auto jitOrError = ExecutionEngine::create(*module);
158   ASSERT_TRUE(!!jitOrError);
159   auto jit = std::move(jitOrError.get());
160 
161   llvm::Error error = jit->invoke("one_ranked", &*A);
162   ASSERT_TRUE(!error);
163   count = 1;
164   for (float &elt : *A) {
165     if (count == 6)
166       EXPECT_EQ(elt, 42.);
167     else
168       EXPECT_EQ(elt, count);
169     count++;
170   }
171 }
172 
173 TEST(NativeMemRefJit, BasicMemref) {
174   constexpr int K = 3;
175   constexpr int M = 7;
176   // Prepare arguments beforehand.
177   auto init = [=](float &elt, ArrayRef<int64_t> indices) {
178     assert(indices.size() == 2);
179     elt = M * indices[0] + indices[1];
180   };
181   int64_t shape[] = {K, M};
182   int64_t shapeAlloc[] = {K + 1, M + 1};
183   OwningMemRef<float, 2> A(shape, shapeAlloc, init);
184   ASSERT_EQ(A->sizes[0], K);
185   ASSERT_EQ(A->sizes[1], M);
186   ASSERT_EQ(A->strides[0], M + 1);
187   ASSERT_EQ(A->strides[1], 1);
188   for (int i = 0; i < K; ++i) {
189     for (int j = 0; j < M; ++j) {
190       EXPECT_EQ((A[{i, j}]), i * M + j);
191       EXPECT_EQ(&(A[{i, j}]), &((*A)[i][j]));
192     }
193   }
194   std::string moduleStr = R"mlir(
195   func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
196     %x = constant 2 : index
197     %y = constant 1 : index
198     %cst42 = constant 42.0 : f32
199     memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
200     memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
201     return
202   }
203   )mlir";
204   DialectRegistry registry;
205   registerAllDialects(registry);
206   registerLLVMDialectTranslation(registry);
207   MLIRContext context(registry);
208   OwningModuleRef module = parseSourceString(moduleStr, &context);
209   ASSERT_TRUE(!!module);
210   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
211   auto jitOrError = ExecutionEngine::create(*module);
212   ASSERT_TRUE(!!jitOrError);
213   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
214 
215   llvm::Error error = jit->invoke("rank2_memref", &*A, &*A);
216   ASSERT_TRUE(!error);
217   EXPECT_EQ(((*A)[1][2]), 42.);
218   EXPECT_EQ((A[{2, 1}]), 42.);
219 }
220 
221 // A helper function that will be called from the JIT
222 static void memref_multiply(::StridedMemRefType<float, 2> *memref,
223                             int32_t coefficient) {
224   for (float &elt : *memref)
225     elt *= coefficient;
226 }
227 
228 TEST(NativeMemRefJit, JITCallback) {
229   constexpr int K = 2;
230   constexpr int M = 2;
231   int64_t shape[] = {K, M};
232   int64_t shapeAlloc[] = {K + 1, M + 1};
233   OwningMemRef<float, 2> A(shape, shapeAlloc);
234   int count = 1;
235   for (float &elt : *A)
236     elt = count++;
237 
238   std::string moduleStr = R"mlir(
239   func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32)  attributes { llvm.emit_c_interface }
240   func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
241     %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
242     call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
243     return
244   }
245   )mlir";
246   DialectRegistry registry;
247   registerAllDialects(registry);
248   registerLLVMDialectTranslation(registry);
249   MLIRContext context(registry);
250   auto module = parseSourceString(moduleStr, &context);
251   ASSERT_TRUE(!!module);
252   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
253   auto jitOrError = ExecutionEngine::create(*module);
254   ASSERT_TRUE(!!jitOrError);
255   auto jit = std::move(jitOrError.get());
256   // Define any extra symbols so they're available at runtime.
257   jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
258     llvm::orc::SymbolMap symbolMap;
259     symbolMap[interner("_mlir_ciface_callback")] =
260         llvm::JITEvaluatedSymbol::fromPointer(memref_multiply);
261     return symbolMap;
262   });
263 
264   int32_t coefficient = 3.;
265   llvm::Error error = jit->invoke("caller_for_callback", &*A, coefficient);
266   ASSERT_TRUE(!error);
267   count = 1;
268   for (float elt : *A)
269     ASSERT_EQ(elt, coefficient * count++);
270 }
271 
272 #endif // _WIN32
273