1 //===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements helper functions to call common simple C functions in
// LLVMIR (e.g. among others to support printing and debugging).
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/Support/LLVM.h"
19 
20 using namespace mlir;
21 using namespace mlir::LLVM;
22 
/// Helper functions to lookup or create the declaration for commonly used
/// external C function calls. The list of functions provided here must be
/// implemented separately (e.g. as part of a support runtime library or as
/// part of the libc).
///
/// The constants below are the symbol names under which the corresponding
/// declarations are looked up in (or added to) a module.

// Names of the printing helpers.
static constexpr llvm::StringRef kPrintI64 = "printI64";
static constexpr llvm::StringRef kPrintU64 = "printU64";
static constexpr llvm::StringRef kPrintF32 = "printF32";
static constexpr llvm::StringRef kPrintF64 = "printF64";
static constexpr llvm::StringRef kPrintOpen = "printOpen";
static constexpr llvm::StringRef kPrintClose = "printClose";
static constexpr llvm::StringRef kPrintComma = "printComma";
static constexpr llvm::StringRef kPrintNewline = "printNewline";
// Names of the memory allocation and deallocation functions.
static constexpr llvm::StringRef kMalloc = "malloc";
static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
static constexpr llvm::StringRef kFree = "free";
38 
39 /// Generic print function lookupOrCreate helper.
40 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
41                                               ArrayRef<Type> paramTypes,
42                                               Type resultType) {
43   auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
44   if (func)
45     return func;
46   OpBuilder b(moduleOp.getBodyRegion());
47   return b.create<LLVM::LLVMFuncOp>(
48       moduleOp->getLoc(), name,
49       LLVM::LLVMFunctionType::get(resultType, paramTypes));
50 }
51 
52 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
53   return lookupOrCreateFn(moduleOp, kPrintI64,
54                           IntegerType::get(moduleOp->getContext(), 64),
55                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
56 }
57 
58 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
59   return lookupOrCreateFn(moduleOp, kPrintU64,
60                           IntegerType::get(moduleOp->getContext(), 64),
61                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
62 }
63 
64 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
65   return lookupOrCreateFn(moduleOp, kPrintF32,
66                           Float32Type::get(moduleOp->getContext()),
67                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
68 }
69 
70 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
71   return lookupOrCreateFn(moduleOp, kPrintF64,
72                           Float64Type::get(moduleOp->getContext()),
73                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
74 }
75 
76 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
77   return lookupOrCreateFn(moduleOp, kPrintOpen, {},
78                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
79 }
80 
81 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
82   return lookupOrCreateFn(moduleOp, kPrintClose, {},
83                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
84 }
85 
86 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
87   return lookupOrCreateFn(moduleOp, kPrintComma, {},
88                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
89 }
90 
91 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
92   return lookupOrCreateFn(moduleOp, kPrintNewline, {},
93                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
94 }
95 
96 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
97                                                     Type indexType) {
98   return LLVM::lookupOrCreateFn(
99       moduleOp, kMalloc, indexType,
100       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
101 }
102 
103 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
104                                                           Type indexType) {
105   return LLVM::lookupOrCreateFn(
106       moduleOp, kAlignedAlloc, {indexType, indexType},
107       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
108 }
109 
110 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
111   return LLVM::lookupOrCreateFn(
112       moduleOp, kFree,
113       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)),
114       LLVM::LLVMVoidType::get(moduleOp->getContext()));
115 }
116 
117 Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
118                                                    LLVM::LLVMFuncOp fn,
119                                                    ValueRange paramTypes,
120                                                    ArrayRef<Type> resultTypes) {
121   return b
122       .create<LLVM::CallOp>(loc, resultTypes, b.getSymbolRefAttr(fn),
123                             paramTypes)
124       ->getResults();
125 }
126