1 //===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements helpers for declaring and calling common, simple C
// functions from the LLVM dialect (e.g., among other uses, to support printing
// and debugging).
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/Support/LLVM.h"
19 
20 using namespace mlir;
21 using namespace mlir::LLVM;
22 
23 /// Helper functions to lookup or create the declaration for commonly used
24 /// external C function calls. The list of functions provided here must be
25 /// implemented separately (e.g. as  part of a support runtime library or as
26 /// part of the libc).
27 static constexpr llvm::StringRef kPrintI64 = "printI64";
28 static constexpr llvm::StringRef kPrintU64 = "printU64";
29 static constexpr llvm::StringRef kPrintF32 = "printF32";
30 static constexpr llvm::StringRef kPrintF64 = "printF64";
31 static constexpr llvm::StringRef kPrintOpen = "printOpen";
32 static constexpr llvm::StringRef kPrintClose = "printClose";
33 static constexpr llvm::StringRef kPrintComma = "printComma";
34 static constexpr llvm::StringRef kPrintNewline = "printNewline";
35 static constexpr llvm::StringRef kMalloc = "malloc";
36 static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
37 static constexpr llvm::StringRef kFree = "free";
38 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
39 
40 /// Generic print function lookupOrCreate helper.
41 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
42                                               ArrayRef<Type> paramTypes,
43                                               Type resultType) {
44   auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
45   if (func)
46     return func;
47   OpBuilder b(moduleOp.getBodyRegion());
48   return b.create<LLVM::LLVMFuncOp>(
49       moduleOp->getLoc(), name,
50       LLVM::LLVMFunctionType::get(resultType, paramTypes));
51 }
52 
53 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
54   return lookupOrCreateFn(moduleOp, kPrintI64,
55                           IntegerType::get(moduleOp->getContext(), 64),
56                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
57 }
58 
59 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
60   return lookupOrCreateFn(moduleOp, kPrintU64,
61                           IntegerType::get(moduleOp->getContext(), 64),
62                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
63 }
64 
65 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
66   return lookupOrCreateFn(moduleOp, kPrintF32,
67                           Float32Type::get(moduleOp->getContext()),
68                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
69 }
70 
71 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
72   return lookupOrCreateFn(moduleOp, kPrintF64,
73                           Float64Type::get(moduleOp->getContext()),
74                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
75 }
76 
77 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
78   return lookupOrCreateFn(moduleOp, kPrintOpen, {},
79                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
80 }
81 
82 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
83   return lookupOrCreateFn(moduleOp, kPrintClose, {},
84                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
85 }
86 
87 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
88   return lookupOrCreateFn(moduleOp, kPrintComma, {},
89                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
90 }
91 
92 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
93   return lookupOrCreateFn(moduleOp, kPrintNewline, {},
94                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
95 }
96 
97 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
98                                                     Type indexType) {
99   return LLVM::lookupOrCreateFn(
100       moduleOp, kMalloc, indexType,
101       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
102 }
103 
104 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
105                                                           Type indexType) {
106   return LLVM::lookupOrCreateFn(
107       moduleOp, kAlignedAlloc, {indexType, indexType},
108       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
109 }
110 
111 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
112   return LLVM::lookupOrCreateFn(
113       moduleOp, kFree,
114       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)),
115       LLVM::LLVMVoidType::get(moduleOp->getContext()));
116 }
117 
118 LLVM::LLVMFuncOp
119 mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
120                                        Type unrankedDescriptorType) {
121   return LLVM::lookupOrCreateFn(
122       moduleOp, kMemRefCopy,
123       ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
124       LLVM::LLVMVoidType::get(moduleOp->getContext()));
125 }
126 
127 Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
128                                                    LLVM::LLVMFuncOp fn,
129                                                    ValueRange paramTypes,
130                                                    ArrayRef<Type> resultTypes) {
131   return b
132       .create<LLVM::CallOp>(loc, resultTypes, SymbolRefAttr::get(fn),
133                             paramTypes)
134       ->getResults();
135 }
136