1 //===- Invoke.cpp ------------------------------------*- C++ -*-===// 2 // 3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" 10 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 11 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" 12 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 13 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 14 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 15 #include "mlir/Dialect/Linalg/Passes.h" 16 #include "mlir/ExecutionEngine/CRunnerUtils.h" 17 #include "mlir/ExecutionEngine/ExecutionEngine.h" 18 #include "mlir/ExecutionEngine/MemRefUtils.h" 19 #include "mlir/ExecutionEngine/RunnerUtils.h" 20 #include "mlir/IR/MLIRContext.h" 21 #include "mlir/InitAllDialects.h" 22 #include "mlir/Parser.h" 23 #include "mlir/Pass/PassManager.h" 24 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" 25 #include "mlir/Target/LLVMIR/Export.h" 26 #include "llvm/Support/TargetSelect.h" 27 #include "llvm/Support/raw_ostream.h" 28 29 #include "gmock/gmock.h" 30 31 using namespace mlir; 32 33 static struct LLVMInitializer { 34 LLVMInitializer() { 35 llvm::InitializeNativeTarget(); 36 llvm::InitializeNativeTargetAsmPrinter(); 37 } 38 } initializer; 39 40 /// Simple conversion pipeline for the purpose of testing sources written in 41 /// dialects lowering to LLVM Dialect. 42 static LogicalResult lowerToLLVMDialect(ModuleOp module) { 43 PassManager pm(module.getContext()); 44 pm.addPass(mlir::createMemRefToLLVMPass()); 45 pm.addPass(mlir::createLowerToLLVMPass()); 46 pm.addPass(mlir::createReconcileUnrealizedCastsPass()); 47 return pm.run(module); 48 } 49 50 // The JIT isn't supported on Windows at that time 51 #ifndef _WIN32 52 53 TEST(MLIRExecutionEngine, AddInteger) { 54 std::string moduleStr = R"mlir( 55 func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } { 56 %res = std.addi %arg0, %arg0 : i32 57 return %res : i32 58 } 59 )mlir"; 60 DialectRegistry registry; 61 registerAllDialects(registry); 62 registerLLVMDialectTranslation(registry); 63 MLIRContext context(registry); 64 OwningModuleRef module = parseSourceString(moduleStr, &context); 65 ASSERT_TRUE(!!module); 66 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 67 auto jitOrError = ExecutionEngine::create(*module); 68 ASSERT_TRUE(!!jitOrError); 69 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 70 // The result of the function must be passed as output argument. 71 int result = 0; 72 llvm::Error error = 73 jit->invoke("foo", 42, ExecutionEngine::Result<int>(result)); 74 ASSERT_TRUE(!error); 75 ASSERT_EQ(result, 42 + 42); 76 } 77 78 TEST(MLIRExecutionEngine, SubtractFloat) { 79 std::string moduleStr = R"mlir( 80 func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } { 81 %res = std.subf %arg0, %arg1 : f32 82 return %res : f32 83 } 84 )mlir"; 85 DialectRegistry registry; 86 registerAllDialects(registry); 87 registerLLVMDialectTranslation(registry); 88 MLIRContext context(registry); 89 OwningModuleRef module = parseSourceString(moduleStr, &context); 90 ASSERT_TRUE(!!module); 91 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 92 auto jitOrError = ExecutionEngine::create(*module); 93 ASSERT_TRUE(!!jitOrError); 94 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 95 // The result of the function must be passed as output argument. 96 float result = -1; 97 llvm::Error error = 98 jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result)); 99 ASSERT_TRUE(!error); 100 ASSERT_EQ(result, 42.f); 101 } 102 103 TEST(NativeMemRefJit, ZeroRankMemref) { 104 OwningMemRef<float, 0> A({}); 105 A[{}] = 42.; 106 ASSERT_EQ(*A->data, 42); 107 A[{}] = 0; 108 std::string moduleStr = R"mlir( 109 func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } { 110 %cst42 = constant 42.0 : f32 111 memref.store %cst42, %arg0[] : memref<f32> 112 return 113 } 114 )mlir"; 115 DialectRegistry registry; 116 registerAllDialects(registry); 117 registerLLVMDialectTranslation(registry); 118 MLIRContext context(registry); 119 auto module = parseSourceString(moduleStr, &context); 120 ASSERT_TRUE(!!module); 121 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 122 auto jitOrError = ExecutionEngine::create(*module); 123 ASSERT_TRUE(!!jitOrError); 124 auto jit = std::move(jitOrError.get()); 125 126 llvm::Error error = jit->invoke("zero_ranked", &*A); 127 ASSERT_TRUE(!error); 128 EXPECT_EQ((A[{}]), 42.); 129 for (float &elt : *A) 130 EXPECT_EQ(&elt, &(A[{}])); 131 } 132 133 TEST(NativeMemRefJit, RankOneMemref) { 134 int64_t shape[] = {9}; 135 OwningMemRef<float, 1> A(shape); 136 int count = 1; 137 for (float &elt : *A) { 138 EXPECT_EQ(&elt, &(A[{count - 1}])); 139 elt = count++; 140 } 141 142 std::string moduleStr = R"mlir( 143 func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } { 144 %cst42 = constant 42.0 : f32 145 %cst5 = constant 5 : index 146 memref.store %cst42, %arg0[%cst5] : memref<?xf32> 147 return 148 } 149 )mlir"; 150 DialectRegistry registry; 151 registerAllDialects(registry); 152 registerLLVMDialectTranslation(registry); 153 MLIRContext context(registry); 154 auto module = parseSourceString(moduleStr, &context); 155 ASSERT_TRUE(!!module); 156 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 157 auto jitOrError = ExecutionEngine::create(*module); 158 ASSERT_TRUE(!!jitOrError); 159 auto jit = std::move(jitOrError.get()); 160 161 llvm::Error error = jit->invoke("one_ranked", &*A); 162 ASSERT_TRUE(!error); 163 count = 1; 164 for (float &elt : *A) { 165 if (count == 6) 166 EXPECT_EQ(elt, 42.); 167 else 168 EXPECT_EQ(elt, count); 169 count++; 170 } 171 } 172 173 TEST(NativeMemRefJit, BasicMemref) { 174 constexpr int K = 3; 175 constexpr int M = 7; 176 // Prepare arguments beforehand. 177 auto init = [=](float &elt, ArrayRef<int64_t> indices) { 178 assert(indices.size() == 2); 179 elt = M * indices[0] + indices[1]; 180 }; 181 int64_t shape[] = {K, M}; 182 int64_t shapeAlloc[] = {K + 1, M + 1}; 183 OwningMemRef<float, 2> A(shape, shapeAlloc, init); 184 ASSERT_EQ(A->sizes[0], K); 185 ASSERT_EQ(A->sizes[1], M); 186 ASSERT_EQ(A->strides[0], M + 1); 187 ASSERT_EQ(A->strides[1], 1); 188 for (int i = 0; i < K; ++i) { 189 for (int j = 0; j < M; ++j) { 190 EXPECT_EQ((A[{i, j}]), i * M + j); 191 EXPECT_EQ(&(A[{i, j}]), &((*A)[i][j])); 192 } 193 } 194 std::string moduleStr = R"mlir( 195 func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } { 196 %x = constant 2 : index 197 %y = constant 1 : index 198 %cst42 = constant 42.0 : f32 199 memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32> 200 memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32> 201 return 202 } 203 )mlir"; 204 DialectRegistry registry; 205 registerAllDialects(registry); 206 registerLLVMDialectTranslation(registry); 207 MLIRContext context(registry); 208 OwningModuleRef module = parseSourceString(moduleStr, &context); 209 ASSERT_TRUE(!!module); 210 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 211 auto jitOrError = ExecutionEngine::create(*module); 212 ASSERT_TRUE(!!jitOrError); 213 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 214 215 llvm::Error error = jit->invoke("rank2_memref", &*A, &*A); 216 ASSERT_TRUE(!error); 217 EXPECT_EQ(((*A)[1][2]), 42.); 218 EXPECT_EQ((A[{2, 1}]), 42.); 219 } 220 221 // A helper function that will be called from the JIT 222 static void memref_multiply(::StridedMemRefType<float, 2> *memref, 223 int32_t coefficient) { 224 for (float &elt : *memref) 225 elt *= coefficient; 226 } 227 228 TEST(NativeMemRefJit, JITCallback) { 229 constexpr int K = 2; 230 constexpr int M = 2; 231 int64_t shape[] = {K, M}; 232 int64_t shapeAlloc[] = {K + 1, M + 1}; 233 OwningMemRef<float, 2> A(shape, shapeAlloc); 234 int count = 1; 235 for (float &elt : *A) 236 elt = count++; 237 238 std::string moduleStr = R"mlir( 239 func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } 240 func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } { 241 %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32> 242 call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> () 243 return 244 } 245 )mlir"; 246 DialectRegistry registry; 247 registerAllDialects(registry); 248 registerLLVMDialectTranslation(registry); 249 MLIRContext context(registry); 250 auto module = parseSourceString(moduleStr, &context); 251 ASSERT_TRUE(!!module); 252 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 253 auto jitOrError = ExecutionEngine::create(*module); 254 ASSERT_TRUE(!!jitOrError); 255 auto jit = std::move(jitOrError.get()); 256 // Define any extra symbols so they're available at runtime. 257 jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) { 258 llvm::orc::SymbolMap symbolMap; 259 symbolMap[interner("_mlir_ciface_callback")] = 260 llvm::JITEvaluatedSymbol::fromPointer(memref_multiply); 261 return symbolMap; 262 }); 263 264 int32_t coefficient = 3.; 265 llvm::Error error = jit->invoke("caller_for_callback", &*A, coefficient); 266 ASSERT_TRUE(!error); 267 count = 1; 268 for (float elt : *A) 269 ASSERT_EQ(elt, coefficient * count++); 270 } 271 272 #endif // _WIN32 273