1 //===- Invoke.cpp ------------------------------------*- C++ -*-===// 2 // 3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 10 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" 11 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" 12 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 13 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" 14 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 15 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 16 #include "mlir/Dialect/Func/IR/FuncOps.h" 17 #include "mlir/Dialect/Linalg/Passes.h" 18 #include "mlir/ExecutionEngine/CRunnerUtils.h" 19 #include "mlir/ExecutionEngine/ExecutionEngine.h" 20 #include "mlir/ExecutionEngine/MemRefUtils.h" 21 #include "mlir/ExecutionEngine/RunnerUtils.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "mlir/InitAllDialects.h" 24 #include "mlir/Parser/Parser.h" 25 #include "mlir/Pass/PassManager.h" 26 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" 27 #include "mlir/Target/LLVMIR/Export.h" 28 #include "llvm/Support/TargetSelect.h" 29 #include "llvm/Support/raw_ostream.h" 30 31 #include "gmock/gmock.h" 32 33 using namespace mlir; 34 35 // The JIT isn't supported on Windows at that time 36 #ifndef _WIN32 37 38 static struct LLVMInitializer { 39 LLVMInitializer() { 40 llvm::InitializeNativeTarget(); 41 llvm::InitializeNativeTargetAsmPrinter(); 42 } 43 } initializer; 44 45 /// Simple conversion pipeline for the purpose of testing sources written in 46 /// dialects lowering to LLVM Dialect. 47 static LogicalResult lowerToLLVMDialect(ModuleOp module) { 48 PassManager pm(module.getContext()); 49 pm.addPass(mlir::createMemRefToLLVMPass()); 50 pm.addNestedPass<func::FuncOp>( 51 mlir::arith::createConvertArithmeticToLLVMPass()); 52 pm.addPass(mlir::createConvertFuncToLLVMPass()); 53 pm.addPass(mlir::createReconcileUnrealizedCastsPass()); 54 return pm.run(module); 55 } 56 57 TEST(MLIRExecutionEngine, AddInteger) { 58 std::string moduleStr = R"mlir( 59 func.func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } { 60 %res = arith.addi %arg0, %arg0 : i32 61 return %res : i32 62 } 63 )mlir"; 64 DialectRegistry registry; 65 registerAllDialects(registry); 66 registerLLVMDialectTranslation(registry); 67 MLIRContext context(registry); 68 OwningOpRef<ModuleOp> module = 69 parseSourceString<ModuleOp>(moduleStr, &context); 70 ASSERT_TRUE(!!module); 71 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 72 auto jitOrError = ExecutionEngine::create(*module); 73 ASSERT_TRUE(!!jitOrError); 74 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 75 // The result of the function must be passed as output argument. 76 int result = 0; 77 llvm::Error error = 78 jit->invoke("foo", 42, ExecutionEngine::Result<int>(result)); 79 ASSERT_TRUE(!error); 80 ASSERT_EQ(result, 42 + 42); 81 } 82 83 TEST(MLIRExecutionEngine, SubtractFloat) { 84 std::string moduleStr = R"mlir( 85 func.func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } { 86 %res = arith.subf %arg0, %arg1 : f32 87 return %res : f32 88 } 89 )mlir"; 90 DialectRegistry registry; 91 registerAllDialects(registry); 92 registerLLVMDialectTranslation(registry); 93 MLIRContext context(registry); 94 OwningOpRef<ModuleOp> module = 95 parseSourceString<ModuleOp>(moduleStr, &context); 96 ASSERT_TRUE(!!module); 97 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 98 auto jitOrError = ExecutionEngine::create(*module); 99 ASSERT_TRUE(!!jitOrError); 100 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 101 // The result of the function must be passed as output argument. 102 float result = -1; 103 llvm::Error error = 104 jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result)); 105 ASSERT_TRUE(!error); 106 ASSERT_EQ(result, 42.f); 107 } 108 109 TEST(NativeMemRefJit, ZeroRankMemref) { 110 OwningMemRef<float, 0> a({}); 111 a[{}] = 42.; 112 ASSERT_EQ(*a->data, 42); 113 a[{}] = 0; 114 std::string moduleStr = R"mlir( 115 func.func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } { 116 %cst42 = arith.constant 42.0 : f32 117 memref.store %cst42, %arg0[] : memref<f32> 118 return 119 } 120 )mlir"; 121 DialectRegistry registry; 122 registerAllDialects(registry); 123 registerLLVMDialectTranslation(registry); 124 MLIRContext context(registry); 125 auto module = parseSourceString<ModuleOp>(moduleStr, &context); 126 ASSERT_TRUE(!!module); 127 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 128 auto jitOrError = ExecutionEngine::create(*module); 129 ASSERT_TRUE(!!jitOrError); 130 auto jit = std::move(jitOrError.get()); 131 132 llvm::Error error = jit->invoke("zero_ranked", &*a); 133 ASSERT_TRUE(!error); 134 EXPECT_EQ((a[{}]), 42.); 135 for (float &elt : *a) 136 EXPECT_EQ(&elt, &(a[{}])); 137 } 138 139 TEST(NativeMemRefJit, RankOneMemref) { 140 int64_t shape[] = {9}; 141 OwningMemRef<float, 1> a(shape); 142 int count = 1; 143 for (float &elt : *a) { 144 EXPECT_EQ(&elt, &(a[{count - 1}])); 145 elt = count++; 146 } 147 148 std::string moduleStr = R"mlir( 149 func.func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } { 150 %cst42 = arith.constant 42.0 : f32 151 %cst5 = arith.constant 5 : index 152 memref.store %cst42, %arg0[%cst5] : memref<?xf32> 153 return 154 } 155 )mlir"; 156 DialectRegistry registry; 157 registerAllDialects(registry); 158 registerLLVMDialectTranslation(registry); 159 MLIRContext context(registry); 160 auto module = parseSourceString<ModuleOp>(moduleStr, &context); 161 ASSERT_TRUE(!!module); 162 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 163 auto jitOrError = ExecutionEngine::create(*module); 164 ASSERT_TRUE(!!jitOrError); 165 auto jit = std::move(jitOrError.get()); 166 167 llvm::Error error = jit->invoke("one_ranked", &*a); 168 ASSERT_TRUE(!error); 169 count = 1; 170 for (float &elt : *a) { 171 if (count == 6) 172 EXPECT_EQ(elt, 42.); 173 else 174 EXPECT_EQ(elt, count); 175 count++; 176 } 177 } 178 179 TEST(NativeMemRefJit, BasicMemref) { 180 constexpr int k = 3; 181 constexpr int m = 7; 182 // Prepare arguments beforehand. 183 auto init = [=](float &elt, ArrayRef<int64_t> indices) { 184 assert(indices.size() == 2); 185 elt = m * indices[0] + indices[1]; 186 }; 187 int64_t shape[] = {k, m}; 188 int64_t shapeAlloc[] = {k + 1, m + 1}; 189 OwningMemRef<float, 2> a(shape, shapeAlloc, init); 190 ASSERT_EQ(a->sizes[0], k); 191 ASSERT_EQ(a->sizes[1], m); 192 ASSERT_EQ(a->strides[0], m + 1); 193 ASSERT_EQ(a->strides[1], 1); 194 for (int i = 0; i < k; ++i) { 195 for (int j = 0; j < m; ++j) { 196 EXPECT_EQ((a[{i, j}]), i * m + j); 197 EXPECT_EQ(&(a[{i, j}]), &((*a)[i][j])); 198 } 199 } 200 std::string moduleStr = R"mlir( 201 func.func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } { 202 %x = arith.constant 2 : index 203 %y = arith.constant 1 : index 204 %cst42 = arith.constant 42.0 : f32 205 memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32> 206 memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32> 207 return 208 } 209 )mlir"; 210 DialectRegistry registry; 211 registerAllDialects(registry); 212 registerLLVMDialectTranslation(registry); 213 MLIRContext context(registry); 214 OwningOpRef<ModuleOp> module = 215 parseSourceString<ModuleOp>(moduleStr, &context); 216 ASSERT_TRUE(!!module); 217 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 218 auto jitOrError = ExecutionEngine::create(*module); 219 ASSERT_TRUE(!!jitOrError); 220 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 221 222 llvm::Error error = jit->invoke("rank2_memref", &*a, &*a); 223 ASSERT_TRUE(!error); 224 EXPECT_EQ(((*a)[1][2]), 42.); 225 EXPECT_EQ((a[{2, 1}]), 42.); 226 } 227 228 // A helper function that will be called from the JIT 229 static void memrefMultiply(::StridedMemRefType<float, 2> *memref, 230 int32_t coefficient) { 231 for (float &elt : *memref) 232 elt *= coefficient; 233 } 234 235 // MSAN does not work with JIT. 236 #if __has_feature(memory_sanitizer) 237 #define MAYBE_JITCallback DISABLED_JITCallback 238 #else 239 #define MAYBE_JITCallback JITCallback 240 #endif 241 TEST(NativeMemRefJit, MAYBE_JITCallback) { 242 constexpr int k = 2; 243 constexpr int m = 2; 244 int64_t shape[] = {k, m}; 245 int64_t shapeAlloc[] = {k + 1, m + 1}; 246 OwningMemRef<float, 2> a(shape, shapeAlloc); 247 int count = 1; 248 for (float &elt : *a) 249 elt = count++; 250 251 std::string moduleStr = R"mlir( 252 func.func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } 253 func.func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } { 254 %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32> 255 call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> () 256 return 257 } 258 )mlir"; 259 DialectRegistry registry; 260 registerAllDialects(registry); 261 registerLLVMDialectTranslation(registry); 262 MLIRContext context(registry); 263 auto module = parseSourceString<ModuleOp>(moduleStr, &context); 264 ASSERT_TRUE(!!module); 265 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 266 auto jitOrError = ExecutionEngine::create(*module); 267 ASSERT_TRUE(!!jitOrError); 268 auto jit = std::move(jitOrError.get()); 269 // Define any extra symbols so they're available at runtime. 270 jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) { 271 llvm::orc::SymbolMap symbolMap; 272 symbolMap[interner("_mlir_ciface_callback")] = 273 llvm::JITEvaluatedSymbol::fromPointer(memrefMultiply); 274 return symbolMap; 275 }); 276 277 int32_t coefficient = 3.; 278 llvm::Error error = jit->invoke("caller_for_callback", &*a, coefficient); 279 ASSERT_TRUE(!error); 280 count = 1; 281 for (float elt : *a) 282 ASSERT_EQ(elt, coefficient * count++); 283 } 284 285 #endif // _WIN32 286