1 //===- Invoke.cpp ------------------------------------*- C++ -*-===// 2 // 3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" 10 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 11 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 12 #include "mlir/Conversion/VectorToSCF/VectorToSCF.h" 13 #include "mlir/Dialect/Linalg/Passes.h" 14 #include "mlir/ExecutionEngine/CRunnerUtils.h" 15 #include "mlir/ExecutionEngine/ExecutionEngine.h" 16 #include "mlir/ExecutionEngine/MemRefUtils.h" 17 #include "mlir/ExecutionEngine/RunnerUtils.h" 18 #include "mlir/IR/MLIRContext.h" 19 #include "mlir/InitAllDialects.h" 20 #include "mlir/Parser.h" 21 #include "mlir/Pass/PassManager.h" 22 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" 23 #include "mlir/Target/LLVMIR/Export.h" 24 #include "llvm/Support/TargetSelect.h" 25 #include "llvm/Support/raw_ostream.h" 26 27 #include "gmock/gmock.h" 28 29 using namespace mlir; 30 31 static struct LLVMInitializer { 32 LLVMInitializer() { 33 llvm::InitializeNativeTarget(); 34 llvm::InitializeNativeTargetAsmPrinter(); 35 } 36 } initializer; 37 38 /// Simple conversion pipeline for the purpose of testing sources written in 39 /// dialects lowering to LLVM Dialect. 40 static LogicalResult lowerToLLVMDialect(ModuleOp module) { 41 PassManager pm(module.getContext()); 42 pm.addPass(mlir::createLowerToLLVMPass()); 43 return pm.run(module); 44 } 45 46 // The JIT isn't supported on Windows at that time 47 #ifndef _WIN32 48 49 TEST(MLIRExecutionEngine, AddInteger) { 50 std::string moduleStr = R"mlir( 51 func @foo(%arg0 : i32) -> i32 attributes { llvm.emit_c_interface } { 52 %res = std.addi %arg0, %arg0 : i32 53 return %res : i32 54 } 55 )mlir"; 56 DialectRegistry registry; 57 registerAllDialects(registry); 58 registerLLVMDialectTranslation(registry); 59 MLIRContext context(registry); 60 OwningModuleRef module = parseSourceString(moduleStr, &context); 61 ASSERT_TRUE(!!module); 62 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 63 auto jitOrError = ExecutionEngine::create(*module); 64 ASSERT_TRUE(!!jitOrError); 65 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 66 // The result of the function must be passed as output argument. 67 int result = 0; 68 llvm::Error error = 69 jit->invoke("foo", 42, ExecutionEngine::Result<int>(result)); 70 ASSERT_TRUE(!error); 71 ASSERT_EQ(result, 42 + 42); 72 } 73 74 TEST(MLIRExecutionEngine, SubtractFloat) { 75 std::string moduleStr = R"mlir( 76 func @foo(%arg0 : f32, %arg1 : f32) -> f32 attributes { llvm.emit_c_interface } { 77 %res = std.subf %arg0, %arg1 : f32 78 return %res : f32 79 } 80 )mlir"; 81 DialectRegistry registry; 82 registerAllDialects(registry); 83 registerLLVMDialectTranslation(registry); 84 MLIRContext context(registry); 85 OwningModuleRef module = parseSourceString(moduleStr, &context); 86 ASSERT_TRUE(!!module); 87 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 88 auto jitOrError = ExecutionEngine::create(*module); 89 ASSERT_TRUE(!!jitOrError); 90 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 91 // The result of the function must be passed as output argument. 92 float result = -1; 93 llvm::Error error = 94 jit->invoke("foo", 43.0f, 1.0f, ExecutionEngine::result(result)); 95 ASSERT_TRUE(!error); 96 ASSERT_EQ(result, 42.f); 97 } 98 99 TEST(NativeMemRefJit, ZeroRankMemref) { 100 OwningMemRef<float, 0> A({}); 101 A[{}] = 42.; 102 ASSERT_EQ(*A->data, 42); 103 A[{}] = 0; 104 std::string moduleStr = R"mlir( 105 func @zero_ranked(%arg0 : memref<f32>) attributes { llvm.emit_c_interface } { 106 %cst42 = constant 42.0 : f32 107 memref.store %cst42, %arg0[] : memref<f32> 108 return 109 } 110 )mlir"; 111 DialectRegistry registry; 112 registerAllDialects(registry); 113 registerLLVMDialectTranslation(registry); 114 MLIRContext context(registry); 115 auto module = parseSourceString(moduleStr, &context); 116 ASSERT_TRUE(!!module); 117 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 118 auto jitOrError = ExecutionEngine::create(*module); 119 ASSERT_TRUE(!!jitOrError); 120 auto jit = std::move(jitOrError.get()); 121 122 llvm::Error error = jit->invoke("zero_ranked", &*A); 123 ASSERT_TRUE(!error); 124 EXPECT_EQ((A[{}]), 42.); 125 for (float &elt : *A) 126 EXPECT_EQ(&elt, &(A[{}])); 127 } 128 129 TEST(NativeMemRefJit, RankOneMemref) { 130 int64_t shape[] = {9}; 131 OwningMemRef<float, 1> A(shape); 132 int count = 1; 133 for (float &elt : *A) { 134 EXPECT_EQ(&elt, &(A[{count - 1}])); 135 elt = count++; 136 } 137 138 std::string moduleStr = R"mlir( 139 func @one_ranked(%arg0 : memref<?xf32>) attributes { llvm.emit_c_interface } { 140 %cst42 = constant 42.0 : f32 141 %cst5 = constant 5 : index 142 memref.store %cst42, %arg0[%cst5] : memref<?xf32> 143 return 144 } 145 )mlir"; 146 DialectRegistry registry; 147 registerAllDialects(registry); 148 registerLLVMDialectTranslation(registry); 149 MLIRContext context(registry); 150 auto module = parseSourceString(moduleStr, &context); 151 ASSERT_TRUE(!!module); 152 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 153 auto jitOrError = ExecutionEngine::create(*module); 154 ASSERT_TRUE(!!jitOrError); 155 auto jit = std::move(jitOrError.get()); 156 157 llvm::Error error = jit->invoke("one_ranked", &*A); 158 ASSERT_TRUE(!error); 159 count = 1; 160 for (float &elt : *A) { 161 if (count == 6) 162 EXPECT_EQ(elt, 42.); 163 else 164 EXPECT_EQ(elt, count); 165 count++; 166 } 167 } 168 169 TEST(NativeMemRefJit, BasicMemref) { 170 constexpr int K = 3; 171 constexpr int M = 7; 172 // Prepare arguments beforehand. 173 auto init = [=](float &elt, ArrayRef<int64_t> indices) { 174 assert(indices.size() == 2); 175 elt = M * indices[0] + indices[1]; 176 }; 177 int64_t shape[] = {K, M}; 178 int64_t shapeAlloc[] = {K + 1, M + 1}; 179 OwningMemRef<float, 2> A(shape, shapeAlloc, init); 180 ASSERT_EQ(A->sizes[0], K); 181 ASSERT_EQ(A->sizes[1], M); 182 ASSERT_EQ(A->strides[0], M + 1); 183 ASSERT_EQ(A->strides[1], 1); 184 for (int i = 0; i < K; ++i) { 185 for (int j = 0; j < M; ++j) { 186 EXPECT_EQ((A[{i, j}]), i * M + j); 187 EXPECT_EQ(&(A[{i, j}]), &((*A)[i][j])); 188 } 189 } 190 std::string moduleStr = R"mlir( 191 func @rank2_memref(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>) attributes { llvm.emit_c_interface } { 192 %x = constant 2 : index 193 %y = constant 1 : index 194 %cst42 = constant 42.0 : f32 195 memref.store %cst42, %arg0[%y, %x] : memref<?x?xf32> 196 memref.store %cst42, %arg1[%x, %y] : memref<?x?xf32> 197 return 198 } 199 )mlir"; 200 DialectRegistry registry; 201 registerAllDialects(registry); 202 registerLLVMDialectTranslation(registry); 203 MLIRContext context(registry); 204 OwningModuleRef module = parseSourceString(moduleStr, &context); 205 ASSERT_TRUE(!!module); 206 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 207 auto jitOrError = ExecutionEngine::create(*module); 208 ASSERT_TRUE(!!jitOrError); 209 std::unique_ptr<ExecutionEngine> jit = std::move(jitOrError.get()); 210 211 llvm::Error error = jit->invoke("rank2_memref", &*A, &*A); 212 ASSERT_TRUE(!error); 213 EXPECT_EQ(((*A)[1][2]), 42.); 214 EXPECT_EQ((A[{2, 1}]), 42.); 215 } 216 217 // A helper function that will be called from the JIT 218 static void memref_multiply(::StridedMemRefType<float, 2> *memref, 219 int32_t coefficient) { 220 for (float &elt : *memref) 221 elt *= coefficient; 222 } 223 224 TEST(NativeMemRefJit, JITCallback) { 225 constexpr int K = 2; 226 constexpr int M = 2; 227 int64_t shape[] = {K, M}; 228 int64_t shapeAlloc[] = {K + 1, M + 1}; 229 OwningMemRef<float, 2> A(shape, shapeAlloc); 230 int count = 1; 231 for (float &elt : *A) 232 elt = count++; 233 234 std::string moduleStr = R"mlir( 235 func private @callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } 236 func @caller_for_callback(%arg0: memref<?x?xf32>, %coefficient: i32) attributes { llvm.emit_c_interface } { 237 %unranked = memref.cast %arg0: memref<?x?xf32> to memref<*xf32> 238 call @callback(%arg0, %coefficient) : (memref<?x?xf32>, i32) -> () 239 return 240 } 241 )mlir"; 242 DialectRegistry registry; 243 registerAllDialects(registry); 244 registerLLVMDialectTranslation(registry); 245 MLIRContext context(registry); 246 auto module = parseSourceString(moduleStr, &context); 247 ASSERT_TRUE(!!module); 248 ASSERT_TRUE(succeeded(lowerToLLVMDialect(*module))); 249 auto jitOrError = ExecutionEngine::create(*module); 250 ASSERT_TRUE(!!jitOrError); 251 auto jit = std::move(jitOrError.get()); 252 // Define any extra symbols so they're available at runtime. 253 jit->registerSymbols([&](llvm::orc::MangleAndInterner interner) { 254 llvm::orc::SymbolMap symbolMap; 255 symbolMap[interner("_mlir_ciface_callback")] = 256 llvm::JITEvaluatedSymbol::fromPointer(memref_multiply); 257 return symbolMap; 258 }); 259 260 int32_t coefficient = 3.; 261 llvm::Error error = jit->invoke("caller_for_callback", &*A, coefficient); 262 ASSERT_TRUE(!error); 263 count = 1; 264 for (float elt : *A) 265 ASSERT_EQ(elt, coefficient * count++); 266 } 267 268 #endif // _WIN32 269