//===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements helper functions to call common simple C functions in
// LLVMIR (e.g. to support printing and debugging, among other uses).
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Support/LLVM.h"

using namespace mlir;
using namespace mlir::LLVM;

/// Helper functions to look up or create the declaration for commonly used
/// external C function calls. The functions listed here must be implemented
/// separately (e.g. as part of a support runtime library or as part of the
/// libc).
27 static constexpr llvm::StringRef kPrintI64 = "printI64"; 28 static constexpr llvm::StringRef kPrintU64 = "printU64"; 29 static constexpr llvm::StringRef kPrintF32 = "printF32"; 30 static constexpr llvm::StringRef kPrintF64 = "printF64"; 31 static constexpr llvm::StringRef kPrintOpen = "printOpen"; 32 static constexpr llvm::StringRef kPrintClose = "printClose"; 33 static constexpr llvm::StringRef kPrintComma = "printComma"; 34 static constexpr llvm::StringRef kPrintNewline = "printNewline"; 35 static constexpr llvm::StringRef kMalloc = "_mlir_alloc"; 36 static constexpr llvm::StringRef kAlignedAlloc = "_mlir_aligned_alloc"; 37 static constexpr llvm::StringRef kFree = "_mlir_free"; 38 static constexpr llvm::StringRef kAlignedFree = "_mlir_aligned_free"; 39 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy"; 40 41 /// Generic print function lookupOrCreate helper. 42 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name, 43 ArrayRef<Type> paramTypes, 44 Type resultType) { 45 auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name); 46 if (func) 47 return func; 48 OpBuilder b(moduleOp.getBodyRegion()); 49 return b.create<LLVM::LLVMFuncOp>( 50 moduleOp->getLoc(), name, 51 LLVM::LLVMFunctionType::get(resultType, paramTypes)); 52 } 53 54 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) { 55 return lookupOrCreateFn(moduleOp, kPrintI64, 56 IntegerType::get(moduleOp->getContext(), 64), 57 LLVM::LLVMVoidType::get(moduleOp->getContext())); 58 } 59 60 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) { 61 return lookupOrCreateFn(moduleOp, kPrintU64, 62 IntegerType::get(moduleOp->getContext(), 64), 63 LLVM::LLVMVoidType::get(moduleOp->getContext())); 64 } 65 66 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) { 67 return lookupOrCreateFn(moduleOp, kPrintF32, 68 Float32Type::get(moduleOp->getContext()), 69 LLVM::LLVMVoidType::get(moduleOp->getContext())); 70 } 71 72 
LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) { 73 return lookupOrCreateFn(moduleOp, kPrintF64, 74 Float64Type::get(moduleOp->getContext()), 75 LLVM::LLVMVoidType::get(moduleOp->getContext())); 76 } 77 78 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) { 79 return lookupOrCreateFn(moduleOp, kPrintOpen, {}, 80 LLVM::LLVMVoidType::get(moduleOp->getContext())); 81 } 82 83 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) { 84 return lookupOrCreateFn(moduleOp, kPrintClose, {}, 85 LLVM::LLVMVoidType::get(moduleOp->getContext())); 86 } 87 88 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) { 89 return lookupOrCreateFn(moduleOp, kPrintComma, {}, 90 LLVM::LLVMVoidType::get(moduleOp->getContext())); 91 } 92 93 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) { 94 return lookupOrCreateFn(moduleOp, kPrintNewline, {}, 95 LLVM::LLVMVoidType::get(moduleOp->getContext())); 96 } 97 98 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp, 99 Type indexType) { 100 return LLVM::lookupOrCreateFn( 101 moduleOp, kMalloc, indexType, 102 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8))); 103 } 104 105 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp, 106 Type indexType) { 107 return LLVM::lookupOrCreateFn( 108 moduleOp, kAlignedAlloc, {indexType, indexType}, 109 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8))); 110 } 111 112 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) { 113 return LLVM::lookupOrCreateFn( 114 moduleOp, kFree, 115 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)), 116 LLVM::LLVMVoidType::get(moduleOp->getContext())); 117 } 118 119 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedFreeFn(ModuleOp moduleOp) { 120 return LLVM::lookupOrCreateFn( 121 moduleOp, kAlignedFree, 122 
LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)), 123 LLVM::LLVMVoidType::get(moduleOp->getContext())); 124 } 125 126 LLVM::LLVMFuncOp 127 mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType, 128 Type unrankedDescriptorType) { 129 return LLVM::lookupOrCreateFn( 130 moduleOp, kMemRefCopy, 131 ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType}, 132 LLVM::LLVMVoidType::get(moduleOp->getContext())); 133 } 134 135 Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc, 136 LLVM::LLVMFuncOp fn, 137 ValueRange paramTypes, 138 ArrayRef<Type> resultTypes) { 139 return b 140 .create<LLVM::CallOp>(loc, resultTypes, SymbolRefAttr::get(fn), 141 paramTypes) 142 ->getResults(); 143 } 144