1 //===- Invoke.cpp ------------------------------------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Conversion/VectorToSCF/VectorToSCF.h"
#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/MemRefUtils.h"
#include "mlir/ExecutionEngine/RunnerUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Export.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"

#include "gmock/gmock.h"
31 
32 using namespace mlir;
33 
// The JIT isn't supported on Windows at this time.
35 #ifndef _WIN32
36 
37 static struct LLVMInitializer {
38   LLVMInitializer() {
39     llvm::InitializeNativeTarget();
40     llvm::InitializeNativeTargetAsmPrinter();
41   }
42 } initializer;
43 
44 /// Simple conversion pipeline for the purpose of testing sources written in
45 /// dialects lowering to LLVM Dialect.
46 static LogicalResult lowerToLLVMDialect(ModuleOp module) {
47   PassManager pm(module.getContext());
48   pm.addPass(mlir::createMemRefToLLVMPass());
49   pm.addNestedPass<FuncOp>(mlir::arith::createConvertArithmeticToLLVMPass());
50   pm.addPass(mlir::createConvertFuncToLLVMPass());
51   pm.addPass(mlir::createReconcileUnrealizedCastsPass());
52   return pm.run(module);
53 }
54 
55 TEST(MLIRExecutionEngine, AddInteger) {
56   std::string moduleStr = R"mlir(
57   func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } {
58     %res = arith.addi %arg0, %arg0 : i32
59     return %res : i32
60   }
61   )mlir";
62   DialectRegistry registry;
63   registerAllDialects(registry);
64   registerLLVMDialectTranslation(registry);
65   MLIRContext context(registry);
66   OwningOpRef<ModuleOp> module =
67       parseSourceString<ModuleOp>(moduleStr, &context);
68   ASSERT_TRUE(!!module);
69   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
70   auto jitOrError = ExecutionEngine::create(*module);
71   ASSERT_TRUE(!!jitOrError);
72   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
73   // The result of the function must be passed as output argument.
74   int result = 0;
75   llvm::Error error =
76       jit->invoke("foo", 42, ExecutionEngine::Result<int>(result));
77   ASSERT_TRUE(!error);
78   ASSERT_EQ(result, 42 + 42);
79 }
80 
81 TEST(MLIRExecutionEngine, SubtractFloat) {
82   std::string moduleStr = R"mlir(
83   func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } {
84     %res = arith.subf %arg0, %arg1 : f32
85     return %res : f32
86   }
87   )mlir";
88   DialectRegistry registry;
89   registerAllDialects(registry);
90   registerLLVMDialectTranslation(registry);
91   MLIRContext context(registry);
92   OwningOpRef<ModuleOp> module =
93       parseSourceString<ModuleOp>(moduleStr, &context);
94   ASSERT_TRUE(!!module);
95   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
96   auto jitOrError = ExecutionEngine::create(*module);
97   ASSERT_TRUE(!!jitOrError);
98   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
99   // The result of the function must be passed as output argument.
100   float result = -1;
101   llvm::Error error =
102       jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result));
103   ASSERT_TRUE(!error);
104   ASSERT_EQ(result, 42.f);
105 }
106 
107 TEST(NativeMemRefJit, ZeroRankMemref) {
108   OwningMemRef<float, 0> a({});
109   a[{}] = 42.;
110   ASSERT_EQ(*a->data, 42);
111   a[{}] = 0;
112   std::string moduleStr = R"mlir(
113   func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } {
114     %cst42 = arith.constant 42.0 : f32
115     memref.store %cst42, %arg0[] : memref<f32>
116     return
117   }
118   )mlir";
119   DialectRegistry registry;
120   registerAllDialects(registry);
121   registerLLVMDialectTranslation(registry);
122   MLIRContext context(registry);
123   auto module = parseSourceString<ModuleOp>(moduleStr, &context);
124   ASSERT_TRUE(!!module);
125   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
126   auto jitOrError = ExecutionEngine::create(*module);
127   ASSERT_TRUE(!!jitOrError);
128   auto jit = std::move(jitOrError.get());
129 
130   llvm::Error error = jit->invoke("zero_ranked", &*a);
131   ASSERT_TRUE(!error);
132   EXPECT_EQ((a[{}]), 42.);
133   for (float &elt : *a)
134     EXPECT_EQ(&elt, &(a[{}]));
135 }
136 
137 TEST(NativeMemRefJit, RankOneMemref) {
138   int64_t shape[] = {9};
139   OwningMemRef<float, 1> a(shape);
140   int count = 1;
141   for (float &elt : *a) {
142     EXPECT_EQ(&elt, &(a[{count - 1}]));
143     elt = count++;
144   }
145 
146   std::string moduleStr = R"mlir(
147   func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } {
148     %cst42 = arith.constant 42.0 : f32
149     %cst5 = arith.constant 5 : index
150     memref.store %cst42, %arg0[%cst5] : memref<?xf32>
151     return
152   }
153   )mlir";
154   DialectRegistry registry;
155   registerAllDialects(registry);
156   registerLLVMDialectTranslation(registry);
157   MLIRContext context(registry);
158   auto module = parseSourceString<ModuleOp>(moduleStr, &context);
159   ASSERT_TRUE(!!module);
160   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
161   auto jitOrError = ExecutionEngine::create(*module);
162   ASSERT_TRUE(!!jitOrError);
163   auto jit = std::move(jitOrError.get());
164 
165   llvm::Error error = jit->invoke("one_ranked", &*a);
166   ASSERT_TRUE(!error);
167   count = 1;
168   for (float &elt : *a) {
169     if (count == 6)
170       EXPECT_EQ(elt, 42.);
171     else
172       EXPECT_EQ(elt, count);
173     count++;
174   }
175 }
176 
177 TEST(NativeMemRefJit, BasicMemref) {
178   constexpr int k = 3;
179   constexpr int m = 7;
180   // Prepare arguments beforehand.
181   auto init = [=](float &elt, ArrayRef<int64_t> indices) {
182     assert(indices.size() == 2);
183     elt = m * indices[0] + indices[1];
184   };
185   int64_t shape[] = {k, m};
186   int64_t shapeAlloc[] = {k + 1, m + 1};
187   OwningMemRef<float, 2> a(shape, shapeAlloc, init);
188   ASSERT_EQ(a->sizes[0], k);
189   ASSERT_EQ(a->sizes[1], m);
190   ASSERT_EQ(a->strides[0], m + 1);
191   ASSERT_EQ(a->strides[1], 1);
192   for (int i = 0; i < k; ++i) {
193     for (int j = 0; j < m; ++j) {
194       EXPECT_EQ((a[{i, j}]), i * m + j);
195       EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j]));
196     }
197   }
198   std::string moduleStr = R"mlir(
199   func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } {
200     %x = arith.constant 2 : index
201     %y = arith.constant 1 : index
202     %cst42 = arith.constant 42.0 : f32
203     memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32>
204     memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32>
205     return
206   }
207   )mlir";
208   DialectRegistry registry;
209   registerAllDialects(registry);
210   registerLLVMDialectTranslation(registry);
211   MLIRContext context(registry);
212   OwningOpRef<ModuleOp> module =
213       parseSourceString<ModuleOp>(moduleStr, &context);
214   ASSERT_TRUE(!!module);
215   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
216   auto jitOrError = ExecutionEngine::create(*module);
217   ASSERT_TRUE(!!jitOrError);
218   std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get());
219 
220   llvm::Error error = jit->invoke("rank2_memref", &*a, &*a);
221   ASSERT_TRUE(!error);
222   EXPECT_EQ(((*a)[1][2]), 42.);
223   EXPECT_EQ((a[{2, 1}]), 42.);
224 }
225 
226 // A helper function that will be called from the JIT
227 static void memrefMultiply(::StridedMemRefType<float, 2> *memref,
228                            int32_t coefficient) {
229   for (float &elt : *memref)
230     elt *= coefficient;
231 }
232 
233 TEST(NativeMemRefJit, JITCallback) {
234   constexpr int k = 2;
235   constexpr int m = 2;
236   int64_t shape[] = {k, m};
237   int64_t shapeAlloc[] = {k + 1, m + 1};
238   OwningMemRef<float, 2> a(shape, shapeAlloc);
239   int count = 1;
240   for (float &elt : *a)
241     elt = count++;
242 
243   std::string moduleStr = R"mlir(
244   func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32)  attributes { llvm.emit_c_interface }
245   func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } {
246     %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32>
247     call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> ()
248     return
249   }
250   )mlir";
251   DialectRegistry registry;
252   registerAllDialects(registry);
253   registerLLVMDialectTranslation(registry);
254   MLIRContext context(registry);
255   auto module = parseSourceString<ModuleOp>(moduleStr, &context);
256   ASSERT_TRUE(!!module);
257   ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module)));
258   auto jitOrError = ExecutionEngine::create(*module);
259   ASSERT_TRUE(!!jitOrError);
260   auto jit = std::move(jitOrError.get());
261   // Define any extra symbols so they're available at runtime.
262   jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) {
263     llvm::orc::SymbolMap symbolMap;
264     symbolMap[interner("_mlir_ciface_callback")] =
265         llvm::JITEvaluatedSymbol::fromPointer(memrefMultiply);
266     return symbolMap;
267   });
268 
269   int32_t coefficient = 3.;
270   llvm::Error error = jit->invoke("caller_for_callback", &*a, coefficient);
271   ASSERT_TRUE(!error);
272   count = 1;
273   for (float elt : *a)
274     ASSERT_EQ(elt, coefficient * count++);
275 }
276 
277 #endif // _WIN32
278