1 //===- FunctionCallUtils.cpp - Utilities for C function calls -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements helper functions to call common simple C functions in 10 // LLVMIR (e.g. amon others to support printing and debugging). 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h" 15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/OpDefinition.h" 18 #include "mlir/Support/LLVM.h" 19 20 using namespace mlir; 21 using namespace mlir::LLVM; 22 23 /// Helper functions to lookup or create the declaration for commonly used 24 /// external C function calls. The list of functions provided here must be 25 /// implemented separately (e.g. as part of a support runtime library or as 26 /// part of the libc). 27 static constexpr llvm::StringRef kPrintI64 = "printI64"; 28 static constexpr llvm::StringRef kPrintU64 = "printU64"; 29 static constexpr llvm::StringRef kPrintF32 = "printF32"; 30 static constexpr llvm::StringRef kPrintF64 = "printF64"; 31 static constexpr llvm::StringRef kPrintOpen = "printOpen"; 32 static constexpr llvm::StringRef kPrintClose = "printClose"; 33 static constexpr llvm::StringRef kPrintComma = "printComma"; 34 static constexpr llvm::StringRef kPrintNewline = "printNewline"; 35 static constexpr llvm::StringRef kMalloc = "malloc"; 36 static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc"; 37 static constexpr llvm::StringRef kFree = "free"; 38 39 /// Generic print function lookupOrCreate helper. 40 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name, 41 ArrayRef<Type> paramTypes, 42 Type resultType) { 43 auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name); 44 if (func) 45 return func; 46 OpBuilder b(moduleOp.getBodyRegion()); 47 return b.create<LLVM::LLVMFuncOp>( 48 moduleOp->getLoc(), name, 49 LLVM::LLVMFunctionType::get(resultType, paramTypes)); 50 } 51 52 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) { 53 return lookupOrCreateFn(moduleOp, kPrintI64, 54 IntegerType::get(moduleOp->getContext(), 64), 55 LLVM::LLVMVoidType::get(moduleOp->getContext())); 56 } 57 58 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) { 59 return lookupOrCreateFn(moduleOp, kPrintU64, 60 IntegerType::get(moduleOp->getContext(), 64), 61 LLVM::LLVMVoidType::get(moduleOp->getContext())); 62 } 63 64 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) { 65 return lookupOrCreateFn(moduleOp, kPrintF32, 66 Float32Type::get(moduleOp->getContext()), 67 LLVM::LLVMVoidType::get(moduleOp->getContext())); 68 } 69 70 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) { 71 return lookupOrCreateFn(moduleOp, kPrintF64, 72 Float64Type::get(moduleOp->getContext()), 73 LLVM::LLVMVoidType::get(moduleOp->getContext())); 74 } 75 76 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) { 77 return lookupOrCreateFn(moduleOp, kPrintOpen, {}, 78 LLVM::LLVMVoidType::get(moduleOp->getContext())); 79 } 80 81 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) { 82 return lookupOrCreateFn(moduleOp, kPrintClose, {}, 83 LLVM::LLVMVoidType::get(moduleOp->getContext())); 84 } 85 86 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) { 87 return lookupOrCreateFn(moduleOp, kPrintComma, {}, 88 LLVM::LLVMVoidType::get(moduleOp->getContext())); 89 } 90 91 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) { 92 return lookupOrCreateFn(moduleOp, kPrintNewline, {}, 93 LLVM::LLVMVoidType::get(moduleOp->getContext())); 94 } 95 96 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp, 97 Type indexType) { 98 return LLVM::lookupOrCreateFn( 99 moduleOp, kMalloc, indexType, 100 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8))); 101 } 102 103 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp, 104 Type indexType) { 105 return LLVM::lookupOrCreateFn( 106 moduleOp, kAlignedAlloc, {indexType, indexType}, 107 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8))); 108 } 109 110 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) { 111 return LLVM::lookupOrCreateFn( 112 moduleOp, kFree, 113 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)), 114 LLVM::LLVMVoidType::get(moduleOp->getContext())); 115 } 116 117 Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc, 118 LLVM::LLVMFuncOp fn, 119 ValueRange paramTypes, 120 ArrayRef<Type> resultTypes) { 121 return b 122 .create<LLVM::CallOp>(loc, resultTypes, b.getSymbolRefAttr(fn), 123 paramTypes) 124 ->getResults(); 125 } 126