1 //===- Invoke.cpp ------------------------------------*- C++ -*-===// 2 // 3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h" 10 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" 11 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 12 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" 13 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 14 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 15 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 16 #include "mlir/Dialect/Linalg/Passes.h" 17 #include "mlir/ExecutionEngine/CRunnerUtils.h" 18 #include "mlir/ExecutionEngine/ExecutionEngine.h" 19 #include "mlir/ExecutionEngine/MemRefUtils.h" 20 #include "mlir/ExecutionEngine/RunnerUtils.h" 21 #include "mlir/IR/MLIRContext.h" 22 #include "mlir/InitAllDialects.h" 23 #include "mlir/Parser.h" 24 #include "mlir/Pass/PassManager.h" 25 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" 26 #include "mlir/Target/LLVMIR/Export.h" 27 #include "llvm/Support/TargetSelect.h" 28 #include "llvm/Support/raw_ostream.h" 29 30 #include "gmock/gmock.h" 31 32 using namespace mlir; 33 34 static struct LLVMInitializer { 35 LLVMInitializer() { 36 llvm::InitializeNativeTarget(); 37 llvm::InitializeNativeTargetAsmPrinter(); 38 } 39 } initializer; 40 41 /// Simple conversion pipeline for the purpose of testing sources written in 42 /// dialects lowering to LLVM Dialect. 43 static LogicalResult lowerToLLVMDialect(ModuleOp module) { 44 PassManager pm(module.getContext()); 45 pm.addPass(mlir::createMemRefToLLVMPass()); 46 pm.addNestedPass<FuncOp>(mlir::arith::createConvertArithmeticToLLVMPass()); 47 pm.addPass(mlir::createLowerToLLVMPass()); 48 pm.addPass(mlir::createReconcileUnrealizedCastsPass()); 49 return pm.run(module); 50 } 51 52 // The JIT isn't supported on Windows at that time 53 #ifndef _WIN32 54 55 TEST(MLIRExecutionEngine, AddInteger) { 56 std::string moduleStr = R"mlir( 57 func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } { 58 %res = arith.addi %arg0, %arg0 : i32 59 return %res : i32 60 } 61 )mlir"; 62 DialectRegistry registry; 63 registerAllDialects(registry); 64 registerLLVMDialectTranslation(registry); 65 MLIRContext context(registry); 66 OwningModuleRef module = parseSourceString(moduleStr, &context); 67 ASSERT_TRUE(!!module); 68 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 69 auto jitOrError = ExecutionEngine::create(*module); 70 ASSERT_TRUE(!!jitOrError); 71 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 72 // The result of the function must be passed as output argument. 73 int result = 0; 74 llvm::Error error = 75 jit->invoke("foo", 42, ExecutionEngine::Result<int>(result)); 76 ASSERT_TRUE(!error); 77 ASSERT_EQ(result, 42 + 42); 78 } 79 80 TEST(MLIRExecutionEngine, SubtractFloat) { 81 std::string moduleStr = R"mlir( 82 func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } { 83 %res = arith.subf %arg0, %arg1 : f32 84 return %res : f32 85 } 86 )mlir"; 87 DialectRegistry registry; 88 registerAllDialects(registry); 89 registerLLVMDialectTranslation(registry); 90 MLIRContext context(registry); 91 OwningModuleRef module = parseSourceString(moduleStr, &context); 92 ASSERT_TRUE(!!module); 93 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 94 auto jitOrError = ExecutionEngine::create(*module); 95 ASSERT_TRUE(!!jitOrError); 96 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 97 // The result of the function must be passed as output argument. 98 float result = -1; 99 llvm::Error error = 100 jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result)); 101 ASSERT_TRUE(!error); 102 ASSERT_EQ(result, 42.f); 103 } 104 105 TEST(NativeMemRefJit, ZeroRankMemref) { 106 OwningMemRef<float, 0> A({}); 107 A[{}] = 42.; 108 ASSERT_EQ(*A->data, 42); 109 A[{}] = 0; 110 std::string moduleStr = R"mlir( 111 func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } { 112 %cst42 = arith.constant 42.0 : f32 113 memref.store %cst42, %arg0[] : memref<f32> 114 return 115 } 116 )mlir"; 117 DialectRegistry registry; 118 registerAllDialects(registry); 119 registerLLVMDialectTranslation(registry); 120 MLIRContext context(registry); 121 auto module = parseSourceString(moduleStr, &context); 122 ASSERT_TRUE(!!module); 123 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 124 auto jitOrError = ExecutionEngine::create(*module); 125 ASSERT_TRUE(!!jitOrError); 126 auto jit = std::move(jitOrError.get()); 127 128 llvm::Error error = jit->invoke("zero_ranked", &*A); 129 ASSERT_TRUE(!error); 130 EXPECT_EQ((A[{}]), 42.); 131 for (float &elt : *A) 132 EXPECT_EQ(&elt, &(A[{}])); 133 } 134 135 TEST(NativeMemRefJit, RankOneMemref) { 136 int64_t shape[] = {9}; 137 OwningMemRef<float, 1> A(shape); 138 int count = 1; 139 for (float &elt : *A) { 140 EXPECT_EQ(&elt, &(A[{count - 1}])); 141 elt = count++; 142 } 143 144 std::string moduleStr = R"mlir( 145 func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } { 146 %cst42 = arith.constant 42.0 : f32 147 %cst5 = arith.constant 5 : index 148 memref.store %cst42, %arg0[%cst5] : memref<?xf32> 149 return 150 } 151 )mlir"; 152 DialectRegistry registry; 153 registerAllDialects(registry); 154 registerLLVMDialectTranslation(registry); 155 MLIRContext context(registry); 156 auto module = parseSourceString(moduleStr, &context); 157 ASSERT_TRUE(!!module); 158 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 159 auto jitOrError = ExecutionEngine::create(*module); 160 ASSERT_TRUE(!!jitOrError); 161 auto jit = std::move(jitOrError.get()); 162 163 llvm::Error error = jit->invoke("one_ranked", &*A); 164 ASSERT_TRUE(!error); 165 count = 1; 166 for (float &elt : *A) { 167 if (count == 6) 168 EXPECT_EQ(elt, 42.); 169 else 170 EXPECT_EQ(elt, count); 171 count++; 172 } 173 } 174 175 TEST(NativeMemRefJit, BasicMemref) { 176 constexpr int K = 3; 177 constexpr int M = 7; 178 // Prepare arguments beforehand. 179 auto init = [=](float &elt, ArrayRef<int64_t> indices) { 180 assert(indices.size() == 2); 181 elt = M * indices[0] + indices[1]; 182 }; 183 int64_t shape[] = {K, M}; 184 int64_t shapeAlloc[] = {K + 1, M + 1}; 185 OwningMemRef<float, 2> A(shape, shapeAlloc, init); 186 ASSERT_EQ(A->sizes[0], K); 187 ASSERT_EQ(A->sizes[1], M); 188 ASSERT_EQ(A->strides[0], M + 1); 189 ASSERT_EQ(A->strides[1], 1); 190 for (int i = 0; i < K; ++i) { 191 for (int j = 0; j < M; ++j) { 192 EXPECT_EQ((A[{i, j}]), i * M + j); 193 EXPECT_EQ(&(A[{i, j}]), &((*A)[i][j])); 194 } 195 } 196 std::string moduleStr = R"mlir( 197 func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } { 198 %x = arith.constant 2 : index 199 %y = arith.constant 1 : index 200 %cst42 = arith.constant 42.0 : f32 201 memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32> 202 memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32> 203 return 204 } 205 )mlir"; 206 DialectRegistry registry; 207 registerAllDialects(registry); 208 registerLLVMDialectTranslation(registry); 209 MLIRContext context(registry); 210 OwningModuleRef module = parseSourceString(moduleStr, &context); 211 ASSERT_TRUE(!!module); 212 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 213 auto jitOrError = ExecutionEngine::create(*module); 214 ASSERT_TRUE(!!jitOrError); 215 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 216 217 llvm::Error error = jit->invoke("rank2_memref", &*A, &*A); 218 ASSERT_TRUE(!error); 219 EXPECT_EQ(((*A)[1][2]), 42.); 220 EXPECT_EQ((A[{2, 1}]), 42.); 221 } 222 223 // A helper function that will be called from the JIT 224 static void memref_multiply(::StridedMemRefType<float, 2> *memref, 225 int32_t coefficient) { 226 for (float &elt : *memref) 227 elt *= coefficient; 228 } 229 230 TEST(NativeMemRefJit, JITCallback) { 231 constexpr int K = 2; 232 constexpr int M = 2; 233 int64_t shape[] = {K, M}; 234 int64_t shapeAlloc[] = {K + 1, M + 1}; 235 OwningMemRef<float, 2> A(shape, shapeAlloc); 236 int count = 1; 237 for (float &elt : *A) 238 elt = count++; 239 240 std::string moduleStr = R"mlir( 241 func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } 242 func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } { 243 %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32> 244 call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> () 245 return 246 } 247 )mlir"; 248 DialectRegistry registry; 249 registerAllDialects(registry); 250 registerLLVMDialectTranslation(registry); 251 MLIRContext context(registry); 252 auto module = parseSourceString(moduleStr, &context); 253 ASSERT_TRUE(!!module); 254 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 255 auto jitOrError = ExecutionEngine::create(*module); 256 ASSERT_TRUE(!!jitOrError); 257 auto jit = std::move(jitOrError.get()); 258 // Define any extra symbols so they're available at runtime. 259 jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) { 260 llvm::orc::SymbolMap symbolMap; 261 symbolMap[interner("_mlir_ciface_callback")] = 262 llvm::JITEvaluatedSymbol::fromPointer(memref_multiply); 263 return symbolMap; 264 }); 265 266 int32_t coefficient = 3.; 267 llvm::Error error = jit->invoke("caller_for_callback", &*A, coefficient); 268 ASSERT_TRUE(!error); 269 count = 1; 270 for (float elt : *A) 271 ASSERT_EQ(elt, coefficient * count++); 272 } 273 274 #endif // _WIN32 275