1 //===- FunctionCallUtils.cpp - Utilities for C function calls -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements helper functions to call common simple C functions in 10 // LLVMIR (e.g. amon others to support printing and debugging). 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/OpDefinition.h" 18 #include "mlir/Support/LLVM.h" 19 20 using namespace mlir; 21 using namespace mlir::LLVM; 22 23 /// Helper functions to lookup or create the declaration for commonly used 24 /// external C function calls. The list of functions provided here must be 25 /// implemented separately (e.g. as part of a support runtime library or as 26 /// part of the libc). 27 static constexpr llvm::StringRef kPrintI64 = "printI64"; 28 static constexpr llvm::StringRef kPrintU64 = "printU64"; 29 static constexpr llvm::StringRef kPrintF32 = "printF32"; 30 static constexpr llvm::StringRef kPrintF64 = "printF64"; 31 static constexpr llvm::StringRef kPrintOpen = "printOpen"; 32 static constexpr llvm::StringRef kPrintClose = "printClose"; 33 static constexpr llvm::StringRef kPrintComma = "printComma"; 34 static constexpr llvm::StringRef kPrintNewline = "printNewline"; 35 static constexpr llvm::StringRef kMalloc = "malloc"; 36 static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc"; 37 static constexpr llvm::StringRef kFree = "free"; 38 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy"; 39 40 /// Generic print function lookupOrCreate helper. 41 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name, 42 ArrayRef<Type> paramTypes, 43 Type resultType) { 44 auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name); 45 if (func) 46 return func; 47 OpBuilder b(moduleOp.getBodyRegion()); 48 return b.create<LLVM::LLVMFuncOp>( 49 moduleOp->getLoc(), name, 50 LLVM::LLVMFunctionType::get(resultType, paramTypes)); 51 } 52 53 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) { 54 return lookupOrCreateFn(moduleOp, kPrintI64, 55 IntegerType::get(moduleOp->getContext(), 64), 56 LLVM::LLVMVoidType::get(moduleOp->getContext())); 57 } 58 59 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) { 60 return lookupOrCreateFn(moduleOp, kPrintU64, 61 IntegerType::get(moduleOp->getContext(), 64), 62 LLVM::LLVMVoidType::get(moduleOp->getContext())); 63 } 64 65 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) { 66 return lookupOrCreateFn(moduleOp, kPrintF32, 67 Float32Type::get(moduleOp->getContext()), 68 LLVM::LLVMVoidType::get(moduleOp->getContext())); 69 } 70 71 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) { 72 return lookupOrCreateFn(moduleOp, kPrintF64, 73 Float64Type::get(moduleOp->getContext()), 74 LLVM::LLVMVoidType::get(moduleOp->getContext())); 75 } 76 77 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) { 78 return lookupOrCreateFn(moduleOp, kPrintOpen, {}, 79 LLVM::LLVMVoidType::get(moduleOp->getContext())); 80 } 81 82 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) { 83 return lookupOrCreateFn(moduleOp, kPrintClose, {}, 84 LLVM::LLVMVoidType::get(moduleOp->getContext())); 85 } 86 87 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) { 88 return lookupOrCreateFn(moduleOp, kPrintComma, {}, 89 LLVM::LLVMVoidType::get(moduleOp->getContext())); 90 } 91 92 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) { 93 return lookupOrCreateFn(moduleOp, kPrintNewline, {}, 94 LLVM::LLVMVoidType::get(moduleOp->getContext())); 95 } 96 97 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp, 98 Type indexType) { 99 return LLVM::lookupOrCreateFn( 100 moduleOp, kMalloc, indexType, 101 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8))); 102 } 103 104 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp, 105 Type indexType) { 106 return LLVM::lookupOrCreateFn( 107 moduleOp, kAlignedAlloc, {indexType, indexType}, 108 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8))); 109 } 110 111 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) { 112 return LLVM::lookupOrCreateFn( 113 moduleOp, kFree, 114 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)), 115 LLVM::LLVMVoidType::get(moduleOp->getContext())); 116 } 117 118 LLVM::LLVMFuncOp 119 mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType, 120 Type unrankedDescriptorType) { 121 return LLVM::lookupOrCreateFn( 122 moduleOp, kMemRefCopy, 123 ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType}, 124 LLVM::LLVMVoidType::get(moduleOp->getContext())); 125 } 126 127 Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc, 128 LLVM::LLVMFuncOp fn, 129 ValueRange paramTypes, 130 ArrayRef<Type> resultTypes) { 131 return b 132 .create<LLVM::CallOp>(loc, resultTypes, SymbolRefAttr::get(fn), 133 paramTypes) 134 ->getResults(); 135 } 136