1 //===- Invoke.cpp ------------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
10 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
11 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
12 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
13 #include "mlir/Dialect/Linalg/Passes.h"
14 #include "mlir/ExecutionEngine/CRunnerUtils.h"
15 #include "mlir/ExecutionEngine/ExecutionEngine.h"
16 #include "mlir/ExecutionEngine/MemRefUtils.h"
17 #include "mlir/ExecutionEngine/RunnerUtils.h"
18 #include "mlir/IR/MLIRContext.h"
19 #include "mlir/InitAllDialects.h"
20 #include "mlir/Parser.h"
21 #include "mlir/Pass/PassManager.h"
22 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
23 #include "mlir/Target/LLVMIR/Export.h"
24 #include "llvm/Support/TargetSelect.h"
25 #include "llvm/Support/raw_ostream.h"
26 
27 #include "gmock/gmock.h"
28 
29 using namespace mlir;
30 
31 static struct LLVMInitializer {
32   LLVMInitializer() {
33     llvm::InitializeNativeTarget();
34     llvm::InitializeNativeTargetAsmPrinter();
35   }
36 } initializer;
37 
38 /// Simple conversion pipeline for the purpose of testing sources written in
39 /// dialects lowering to LLVM Dialect.
40 static LogicalResult lowerToLLVMDialect(ModuleOp module) {
41   PassManager pm(module.getContext());
42   pm.addPass(mlir::createLowerToLLVMPass());
43   return pm.run(module);
44 }
45 
// The JIT isn't supported on Windows at this time.
47 #ifndef _WIN32
48 
49 TEST(MLIRExecutionEngine, AddInteger) {
50   std::string moduleStr = R"mlir(
51   func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
52     %res = std.addi %arg0, %arg0 : i32
53     return %res : i32
54   }
55   )mlir";
56   DialectRegistry registry;
57   registerAllDialects(registry);
58   registerLLVMDialectTranslation(registry);
59   MLIRContext context(registry);
60   OwningModuleRef module = parseSourceString(moduleStr, &context);
61   ASSERT_TRUE(!!module);
62   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
63   auto jitOrError = ExecutionEngine::create(*module);
64   ASSERT_TRUE(!!jitOrError);
65   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
66   // The result of the function must be passed as output argument.
67   int result = 0;
68   llvm::Error error =
69       jit->invoke("foo", 42, ExecutionEngine::Result<int>(result));
70   ASSERT_TRUE(!error);
71   ASSERT_EQ(result, 42 + 42);
72 }
73 
74 TEST(MLIRExecutionEngine, SubtractFloat) {
75   std::string moduleStr = R"mlir(
76   func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
77     %res = std.subf %arg0, %arg1 : f32
78     return %res : f32
79   }
80   )mlir";
81   DialectRegistry registry;
82   registerAllDialects(registry);
83   registerLLVMDialectTranslation(registry);
84   MLIRContext context(registry);
85   OwningModuleRef module = parseSourceString(moduleStr, &context);
86   ASSERT_TRUE(!!module);
87   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
88   auto jitOrError = ExecutionEngine::create(*module);
89   ASSERT_TRUE(!!jitOrError);
90   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
91   // The result of the function must be passed as output argument.
92   float result = -1;
93   llvm::Error error =
94       jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result));
95   ASSERT_TRUE(!error);
96   ASSERT_EQ(result, 42.f);
97 }
98 
99 TEST(NativeMemRefJit, ZeroRankMemref) {
100   OwningMemRef<float, 0> A({});
101   A[{}] = 42.;
102   ASSERT_EQ(*A->data, 42);
103   A[{}] = 0;
104   std::string moduleStr = R"mlir(
105   func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
106     %cst42 = constant 42.0 : f32
107     memref.store %cst42, %arg0[] : memref<f32>
108     return
109   }
110   )mlir";
111   DialectRegistry registry;
112   registerAllDialects(registry);
113   registerLLVMDialectTranslation(registry);
114   MLIRContext context(registry);
115   auto module = parseSourceString(moduleStr, &context);
116   ASSERT_TRUE(!!module);
117   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
118   auto jitOrError = ExecutionEngine::create(*module);
119   ASSERT_TRUE(!!jitOrError);
120   auto jit = std::move(jitOrError.get());
121 
122   llvm::Error error = jit->invoke("zero_ranked", &*A);
123   ASSERT_TRUE(!error);
124   EXPECT_EQ((A[{}]), 42.);
125   for (float &elt : *A)
126     EXPECT_EQ(&elt, &(A[{}]));
127 }
128 
129 TEST(NativeMemRefJit, RankOneMemref) {
130   int64_t shape[] = {9};
131   OwningMemRef<float, 1> A(shape);
132   int count = 1;
133   for (float &elt : *A) {
134     EXPECT_EQ(&elt, &(A[{count - 1}]));
135     elt = count++;
136   }
137 
138   std::string moduleStr = R"mlir(
139   func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
140     %cst42 = constant 42.0 : f32
141     %cst5 = constant 5 : index
142     memref.store %cst42, %arg0[%cst5] : memref<?xf32>
143     return
144   }
145   )mlir";
146   DialectRegistry registry;
147   registerAllDialects(registry);
148   registerLLVMDialectTranslation(registry);
149   MLIRContext context(registry);
150   auto module = parseSourceString(moduleStr, &context);
151   ASSERT_TRUE(!!module);
152   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
153   auto jitOrError = ExecutionEngine::create(*module);
154   ASSERT_TRUE(!!jitOrError);
155   auto jit = std::move(jitOrError.get());
156 
157   llvm::Error error = jit->invoke("one_ranked", &*A);
158   ASSERT_TRUE(!error);
159   count = 1;
160   for (float &elt : *A) {
161     if (count == 6)
162       EXPECT_EQ(elt, 42.);
163     else
164       EXPECT_EQ(elt, count);
165     count++;
166   }
167 }
168 
169 TEST(NativeMemRefJit, BasicMemref) {
170   constexpr int K = 3;
171   constexpr int M = 7;
172   // Prepare arguments beforehand.
173   auto init = [=](float &elt, ArrayRef<int64_t> indices) {
174     assert(indices.size() == 2);
175     elt = M * indices[0] + indices[1];
176   };
177   int64_t shape[] = {K, M};
178   int64_t shapeAlloc[] = {K + 1, M + 1};
179   OwningMemRef<float, 2> A(shape, shapeAlloc, init);
180   ASSERT_EQ(A->sizes[0], K);
181   ASSERT_EQ(A->sizes[1], M);
182   ASSERT_EQ(A->strides[0], M + 1);
183   ASSERT_EQ(A->strides[1], 1);
184   for (int i = 0; i < K; ++i) {
185     for (int j = 0; j < M; ++j) {
186       EXPECT_EQ((A[{i, j}]), i * M + j);
187       EXPECT_EQ(&(A[{i, j}]), &((*A)[i][j]));
188     }
189   }
190   std::string moduleStr = R"mlir(
191   func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
192     %x = constant 2 : index
193     %y = constant 1 : index
194     %cst42 = constant 42.0 : f32
195     memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
196     memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
197     return
198   }
199   )mlir";
200   DialectRegistry registry;
201   registerAllDialects(registry);
202   registerLLVMDialectTranslation(registry);
203   MLIRContext context(registry);
204   OwningModuleRef module = parseSourceString(moduleStr, &context);
205   ASSERT_TRUE(!!module);
206   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
207   auto jitOrError = ExecutionEngine::create(*module);
208   ASSERT_TRUE(!!jitOrError);
209   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
210 
211   llvm::Error error = jit->invoke("rank2_memref", &*A, &*A);
212   ASSERT_TRUE(!error);
213   EXPECT_EQ(((*A)[1][2]), 42.);
214   EXPECT_EQ((A[{2, 1}]), 42.);
215 }
216 
217 // A helper function that will be called from the JIT
218 static void memref_multiply(::StridedMemRefType<float, 2> *memref,
219                             int32_t coefficient) {
220   for (float &elt : *memref)
221     elt *= coefficient;
222 }
223 
224 TEST(NativeMemRefJit, JITCallback) {
225   constexpr int K = 2;
226   constexpr int M = 2;
227   int64_t shape[] = {K, M};
228   int64_t shapeAlloc[] = {K + 1, M + 1};
229   OwningMemRef<float, 2> A(shape, shapeAlloc);
230   int count = 1;
231   for (float &elt : *A)
232     elt = count++;
233 
234   std::string moduleStr = R"mlir(
235   func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32)  attributes { llvm.emit_c_interface }
236   func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
237     %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
238     call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
239     return
240   }
241   )mlir";
242   DialectRegistry registry;
243   registerAllDialects(registry);
244   registerLLVMDialectTranslation(registry);
245   MLIRContext context(registry);
246   auto module = parseSourceString(moduleStr, &context);
247   ASSERT_TRUE(!!module);
248   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
249   auto jitOrError = ExecutionEngine::create(*module);
250   ASSERT_TRUE(!!jitOrError);
251   auto jit = std::move(jitOrError.get());
252   // Define any extra symbols so they're available at runtime.
253   jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
254     llvm::orc::SymbolMap symbolMap;
255     symbolMap[interner("_mlir_ciface_callback")] =
256         llvm::JITEvaluatedSymbol::fromPointer(memref_multiply);
257     return symbolMap;
258   });
259 
260   int32_t coefficient = 3.;
261   llvm::Error error = jit->invoke("caller_for_callback", &*A, coefficient);
262   ASSERT_TRUE(!error);
263   count = 1;
264   for (float elt : *A)
265     ASSERT_EQ(elt, coefficient * count++);
266 }
267 
268 #endif // _WIN32
269