//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements helper functions to call common simple C functions in
// LLVMIR (e.g. among others to support printing and debugging).
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::LLVM;

/// Helper functions to lookup or create the declaration for commonly used
/// external C function calls. The list of functions provided here must be
/// implemented separately (e.g. as part of a support runtime library or as
/// part of the libc).
27 static constexpr llvm::StringRef kPrintI64 = "printI64"; 28 static constexpr llvm::StringRef kPrintU64 = "printU64"; 29 static constexpr llvm::StringRef kPrintF32 = "printF32"; 30 static constexpr llvm::StringRef kPrintF64 = "printF64"; 31 static constexpr llvm::StringRef kPrintOpen = "printOpen"; 32 static constexpr llvm::StringRef kPrintClose = "printClose"; 33 static constexpr llvm::StringRef kPrintComma = "printComma"; 34 static constexpr llvm::StringRef kPrintNewline = "printNewline"; 35 static constexpr llvm::StringRef kMalloc = "malloc"; 36 static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc"; 37 static constexpr llvm::StringRef kFree = "free"; 38 static constexpr llvm::StringRef kGenericAlloc = "_mlir_alloc"; 39 static constexpr llvm::StringRef kGenericAlignedAlloc = "_mlir_aligned_alloc"; 40 static constexpr llvm::StringRef kGenericFree = "_mlir_free"; 41 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy"; 42 43 /// Generic print function lookupOrCreate helper. 
44 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name, 45 ArrayRef<Type> paramTypes, 46 Type resultType) { 47 auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name); 48 if (func) 49 return func; 50 OpBuilder b(moduleOp.getBodyRegion()); 51 return b.create<LLVM::LLVMFuncOp>( 52 moduleOp->getLoc(), name, 53 LLVM::LLVMFunctionType::get(resultType, paramTypes)); 54 } 55 56 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) { 57 return lookupOrCreateFn(moduleOp, kPrintI64, 58 IntegerType::get(moduleOp->getContext(), 64), 59 LLVM::LLVMVoidType::get(moduleOp->getContext())); 60 } 61 62 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) { 63 return lookupOrCreateFn(moduleOp, kPrintU64, 64 IntegerType::get(moduleOp->getContext(), 64), 65 LLVM::LLVMVoidType::get(moduleOp->getContext())); 66 } 67 68 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) { 69 return lookupOrCreateFn(moduleOp, kPrintF32, 70 Float32Type::get(moduleOp->getContext()), 71 LLVM::LLVMVoidType::get(moduleOp->getContext())); 72 } 73 74 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) { 75 return lookupOrCreateFn(moduleOp, kPrintF64, 76 Float64Type::get(moduleOp->getContext()), 77 LLVM::LLVMVoidType::get(moduleOp->getContext())); 78 } 79 80 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) { 81 return lookupOrCreateFn(moduleOp, kPrintOpen, {}, 82 LLVM::LLVMVoidType::get(moduleOp->getContext())); 83 } 84 85 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) { 86 return lookupOrCreateFn(moduleOp, kPrintClose, {}, 87 LLVM::LLVMVoidType::get(moduleOp->getContext())); 88 } 89 90 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) { 91 return lookupOrCreateFn(moduleOp, kPrintComma, {}, 92 LLVM::LLVMVoidType::get(moduleOp->getContext())); 93 } 94 95 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) 
{ 96 return lookupOrCreateFn(moduleOp, kPrintNewline, {}, 97 LLVM::LLVMVoidType::get(moduleOp->getContext())); 98 } 99 100 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp, 101 Type indexType) { 102 return LLVM::lookupOrCreateFn( 103 moduleOp, kMalloc, indexType, 104 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8))); 105 } 106 107 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp, 108 Type indexType) { 109 return LLVM::lookupOrCreateFn( 110 moduleOp, kAlignedAlloc, {indexType, indexType}, 111 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8))); 112 } 113 114 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) { 115 return LLVM::lookupOrCreateFn( 116 moduleOp, kFree, 117 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)), 118 LLVM::LLVMVoidType::get(moduleOp->getContext())); 119 } 120 121 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp, 122 Type indexType) { 123 return LLVM::lookupOrCreateFn( 124 moduleOp, kGenericAlloc, indexType, 125 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8))); 126 } 127 128 LLVM::LLVMFuncOp 129 mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp, 130 Type indexType) { 131 return LLVM::lookupOrCreateFn( 132 moduleOp, kGenericAlignedAlloc, {indexType, indexType}, 133 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8))); 134 } 135 136 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) { 137 return LLVM::lookupOrCreateFn( 138 moduleOp, kGenericFree, 139 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)), 140 LLVM::LLVMVoidType::get(moduleOp->getContext())); 141 } 142 143 LLVM::LLVMFuncOp 144 mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType, 145 Type unrankedDescriptorType) { 146 return LLVM::lookupOrCreateFn( 147 moduleOp, kMemRefCopy, 148 ArrayRef<Type>{indexType, 
unrankedDescriptorType, unrankedDescriptorType}, 149 LLVM::LLVMVoidType::get(moduleOp->getContext())); 150 } 151 152 Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc, 153 LLVM::LLVMFuncOp fn, 154 ValueRange paramTypes, 155 ArrayRef<Type> resultTypes) { 156 return b 157 .create<LLVM::CallOp>(loc, resultTypes, SymbolRefAttr::get(fn), 158 paramTypes) 159 ->getResults(); 160 } 161