1 //===- Invoke.cpp ------------------------------------*- C++ -*-===// 2 // 3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 10 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" 11 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" 12 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 13 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" 14 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 15 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/ExecutionEngine/CRunnerUtils.h" 18 #include "mlir/ExecutionEngine/ExecutionEngine.h" 19 #include "mlir/ExecutionEngine/MemRefUtils.h" 20 #include "mlir/ExecutionEngine/RunnerUtils.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "mlir/InitAllDialects.h" 23 #include "mlir/Parser/Parser.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" 26 #include "mlir/Target/LLVMIR/Export.h" 27 #include "llvm/Support/TargetSelect.h" 28 #include "llvm/Support/raw_ostream.h" 29 30 #include "gmock/gmock.h" 31 32 using namespace mlir; 33 34 // The JIT isn't supported on Windows at that time 35 #ifndef _WIN32 36 37 static struct LLVMInitializer { 38 LLVMInitializer() { 39 llvm::InitializeNativeTarget(); 40 llvm::InitializeNativeTargetAsmPrinter(); 41 } 42 } initializer; 43 44 /// Simple conversion pipeline for the purpose of testing sources written in 45 /// dialects lowering to LLVM Dialect. 46 static LogicalResult lowerToLLVMDialect(ModuleOp module) { 47 PassManager pm(module.getContext()); 48 pm.addPass(mlir::createMemRefToLLVMPass()); 49 pm.addNestedPass<FuncOp>(mlir::arith::createConvertArithmeticToLLVMPass()); 50 pm.addPass(mlir::createConvertFuncToLLVMPass()); 51 pm.addPass(mlir::createReconcileUnrealizedCastsPass()); 52 return pm.run(module); 53 } 54 55 TEST(MLIRExecutionEngine, AddInteger) { 56 std::string moduleStr = R"mlir( 57 func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } { 58 %res = arith.addi %arg0, %arg0 : i32 59 return %res : i32 60 } 61 )mlir"; 62 DialectRegistry registry; 63 registerAllDialects(registry); 64 registerLLVMDialectTranslation(registry); 65 MLIRContext context(registry); 66 OwningOpRef<ModuleOp> module = 67 parseSourceString<ModuleOp>(moduleStr, &context); 68 ASSERT_TRUE(!!module); 69 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 70 auto jitOrError = ExecutionEngine::create(*module); 71 ASSERT_TRUE(!!jitOrError); 72 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 73 // The result of the function must be passed as output argument. 74 int result = 0; 75 llvm::Error error = 76 jit->invoke("foo", 42, ExecutionEngine::Result<int>(result)); 77 ASSERT_TRUE(!error); 78 ASSERT_EQ(result, 42 + 42); 79 } 80 81 TEST(MLIRExecutionEngine, SubtractFloat) { 82 std::string moduleStr = R"mlir( 83 func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } { 84 %res = arith.subf %arg0, %arg1 : f32 85 return %res : f32 86 } 87 )mlir"; 88 DialectRegistry registry; 89 registerAllDialects(registry); 90 registerLLVMDialectTranslation(registry); 91 MLIRContext context(registry); 92 OwningOpRef<ModuleOp> module = 93 parseSourceString<ModuleOp>(moduleStr, &context); 94 ASSERT_TRUE(!!module); 95 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 96 auto jitOrError = ExecutionEngine::create(*module); 97 ASSERT_TRUE(!!jitOrError); 98 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 99 // The result of the function must be passed as output argument. 100 float result = -1; 101 llvm::Error error = 102 jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result)); 103 ASSERT_TRUE(!error); 104 ASSERT_EQ(result, 42.f); 105 } 106 107 TEST(NativeMemRefJit, ZeroRankMemref) { 108 OwningMemRef<float, 0> a({}); 109 a[{}] = 42.; 110 ASSERT_EQ(*a->data, 42); 111 a[{}] = 0; 112 std::string moduleStr = R"mlir( 113 func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } { 114 %cst42 = arith.constant 42.0 : f32 115 memref.store %cst42, %arg0[] : memref<f32> 116 return 117 } 118 )mlir"; 119 DialectRegistry registry; 120 registerAllDialects(registry); 121 registerLLVMDialectTranslation(registry); 122 MLIRContext context(registry); 123 auto module = parseSourceString<ModuleOp>(moduleStr, &context); 124 ASSERT_TRUE(!!module); 125 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 126 auto jitOrError = ExecutionEngine::create(*module); 127 ASSERT_TRUE(!!jitOrError); 128 auto jit = std::move(jitOrError.get()); 129 130 llvm::Error error = jit->invoke("zero_ranked", &*a); 131 ASSERT_TRUE(!error); 132 EXPECT_EQ((a[{}]), 42.); 133 for (float &elt : *a) 134 EXPECT_EQ(&elt, &(a[{}])); 135 } 136 137 TEST(NativeMemRefJit, RankOneMemref) { 138 int64_t shape[] = {9}; 139 OwningMemRef<float, 1> a(shape); 140 int count = 1; 141 for (float &elt : *a) { 142 EXPECT_EQ(&elt, &(a[{count - 1}])); 143 elt = count++; 144 } 145 146 std::string moduleStr = R"mlir( 147 func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } { 148 %cst42 = arith.constant 42.0 : f32 149 %cst5 = arith.constant 5 : index 150 memref.store %cst42, %arg0[%cst5] : memref<?xf32> 151 return 152 } 153 )mlir"; 154 DialectRegistry registry; 155 registerAllDialects(registry); 156 registerLLVMDialectTranslation(registry); 157 MLIRContext context(registry); 158 auto module = parseSourceString<ModuleOp>(moduleStr, &context); 159 ASSERT_TRUE(!!module); 160 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 161 auto jitOrError = ExecutionEngine::create(*module); 162 ASSERT_TRUE(!!jitOrError); 163 auto jit = std::move(jitOrError.get()); 164 165 llvm::Error error = jit->invoke("one_ranked", &*a); 166 ASSERT_TRUE(!error); 167 count = 1; 168 for (float &elt : *a) { 169 if (count == 6) 170 EXPECT_EQ(elt, 42.); 171 else 172 EXPECT_EQ(elt, count); 173 count++; 174 } 175 } 176 177 TEST(NativeMemRefJit, BasicMemref) { 178 constexpr int k = 3; 179 constexpr int m = 7; 180 // Prepare arguments beforehand. 181 auto init = [=](float &elt, ArrayRef<int64_t> indices) { 182 assert(indices.size() == 2); 183 elt = m * indices[0] + indices[1]; 184 }; 185 int64_t shape[] = {k, m}; 186 int64_t shapeAlloc[] = {k + 1, m + 1}; 187 OwningMemRef<float, 2> a(shape, shapeAlloc, init); 188 ASSERT_EQ(a->sizes[0], k); 189 ASSERT_EQ(a->sizes[1], m); 190 ASSERT_EQ(a->strides[0], m + 1); 191 ASSERT_EQ(a->strides[1], 1); 192 for (int i = 0; i < k; ++i) { 193 for (int j = 0; j < m; ++j) { 194 EXPECT_EQ((a[{i, j}]), i * m + j); 195 EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j])); 196 } 197 } 198 std::string moduleStr = R"mlir( 199 func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } { 200 %x = arith.constant 2 : index 201 %y = arith.constant 1 : index 202 %cst42 = arith.constant 42.0 : f32 203 memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32> 204 memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32> 205 return 206 } 207 )mlir"; 208 DialectRegistry registry; 209 registerAllDialects(registry); 210 registerLLVMDialectTranslation(registry); 211 MLIRContext context(registry); 212 OwningOpRef<ModuleOp> module = 213 parseSourceString<ModuleOp>(moduleStr, &context); 214 ASSERT_TRUE(!!module); 215 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 216 auto jitOrError = ExecutionEngine::create(*module); 217 ASSERT_TRUE(!!jitOrError); 218 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 219 220 llvm::Error error = jit->invoke("rank2_memref", &*a, &*a); 221 ASSERT_TRUE(!error); 222 EXPECT_EQ(((*a)[1][2]), 42.); 223 EXPECT_EQ((a[{2, 1}]), 42.); 224 } 225 226 // A helper function that will be called from the JIT 227 static void memrefMultiply(::StridedMemRefType<float, 2> *memref, 228 int32_t coefficient) { 229 for (float &elt : *memref) 230 elt *= coefficient; 231 } 232 233 TEST(NativeMemRefJit, JITCallback) { 234 constexpr int k = 2; 235 constexpr int m = 2; 236 int64_t shape[] = {k, m}; 237 int64_t shapeAlloc[] = {k + 1, m + 1}; 238 OwningMemRef<float, 2> a(shape, shapeAlloc); 239 int count = 1; 240 for (float &elt : *a) 241 elt = count++; 242 243 std::string moduleStr = R"mlir( 244 func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } 245 func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } { 246 %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32> 247 call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> () 248 return 249 } 250 )mlir"; 251 DialectRegistry registry; 252 registerAllDialects(registry); 253 registerLLVMDialectTranslation(registry); 254 MLIRContext context(registry); 255 auto module = parseSourceString<ModuleOp>(moduleStr, &context); 256 ASSERT_TRUE(!!module); 257 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 258 auto jitOrError = ExecutionEngine::create(*module); 259 ASSERT_TRUE(!!jitOrError); 260 auto jit = std::move(jitOrError.get()); 261 // Define any extra symbols so they're available at runtime. 262 jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) { 263 llvm::orc::SymbolMap symbolMap; 264 symbolMap[interner("_mlir_ciface_callback")] = 265 llvm::JITEvaluatedSymbol::fromPointer(memrefMultiply); 266 return symbolMap; 267 }); 268 269 int32_t coefficient = 3.; 270 llvm::Error error = jit->invoke("caller_for_callback", &*a, coefficient); 271 ASSERT_TRUE(!error); 272 count = 1; 273 for (float elt : *a) 274 ASSERT_EQ(elt, coefficient * count++); 275 } 276 277 #endif // _WIN32 278