1 //===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements helper functions to call common simple C functions in
// LLVMIR (e.g. among others to support printing and debugging).
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/Support/LLVM.h"
19 
20 using namespace mlir;
21 using namespace mlir::LLVM;
22 
23 /// Helper functions to lookup or create the declaration for commonly used
24 /// external C function calls. The list of functions provided here must be
25 /// implemented separately (e.g. as  part of a support runtime library or as
26 /// part of the libc).
27 static constexpr llvm::StringRef kPrintI64 = "printI64";
28 static constexpr llvm::StringRef kPrintU64 = "printU64";
29 static constexpr llvm::StringRef kPrintF32 = "printF32";
30 static constexpr llvm::StringRef kPrintF64 = "printF64";
31 static constexpr llvm::StringRef kPrintOpen = "printOpen";
32 static constexpr llvm::StringRef kPrintClose = "printClose";
33 static constexpr llvm::StringRef kPrintComma = "printComma";
34 static constexpr llvm::StringRef kPrintNewline = "printNewline";
35 static constexpr llvm::StringRef kMalloc = "_mlir_alloc";
36 static constexpr llvm::StringRef kAlignedAlloc = "_mlir_aligned_alloc";
37 static constexpr llvm::StringRef kFree = "_mlir_free";
38 static constexpr llvm::StringRef kAlignedFree = "_mlir_aligned_free";
39 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
40 
41 /// Generic print function lookupOrCreate helper.
42 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
43                                               ArrayRef<Type> paramTypes,
44                                               Type resultType) {
45   auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
46   if (func)
47     return func;
48   OpBuilder b(moduleOp.getBodyRegion());
49   return b.create<LLVM::LLVMFuncOp>(
50       moduleOp->getLoc(), name,
51       LLVM::LLVMFunctionType::get(resultType, paramTypes));
52 }
53 
54 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
55   return lookupOrCreateFn(moduleOp, kPrintI64,
56                           IntegerType::get(moduleOp->getContext(), 64),
57                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
58 }
59 
60 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
61   return lookupOrCreateFn(moduleOp, kPrintU64,
62                           IntegerType::get(moduleOp->getContext(), 64),
63                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
64 }
65 
66 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
67   return lookupOrCreateFn(moduleOp, kPrintF32,
68                           Float32Type::get(moduleOp->getContext()),
69                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
70 }
71 
72 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
73   return lookupOrCreateFn(moduleOp, kPrintF64,
74                           Float64Type::get(moduleOp->getContext()),
75                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
76 }
77 
78 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
79   return lookupOrCreateFn(moduleOp, kPrintOpen, {},
80                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
81 }
82 
83 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
84   return lookupOrCreateFn(moduleOp, kPrintClose, {},
85                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
86 }
87 
88 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
89   return lookupOrCreateFn(moduleOp, kPrintComma, {},
90                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
91 }
92 
93 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
94   return lookupOrCreateFn(moduleOp, kPrintNewline, {},
95                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
96 }
97 
98 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
99                                                     Type indexType) {
100   return LLVM::lookupOrCreateFn(
101       moduleOp, kMalloc, indexType,
102       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
103 }
104 
105 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
106                                                           Type indexType) {
107   return LLVM::lookupOrCreateFn(
108       moduleOp, kAlignedAlloc, {indexType, indexType},
109       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
110 }
111 
112 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
113   return LLVM::lookupOrCreateFn(
114       moduleOp, kFree,
115       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)),
116       LLVM::LLVMVoidType::get(moduleOp->getContext()));
117 }
118 
119 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedFreeFn(ModuleOp moduleOp) {
120   return LLVM::lookupOrCreateFn(
121       moduleOp, kAlignedFree,
122       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)),
123       LLVM::LLVMVoidType::get(moduleOp->getContext()));
124 }
125 
126 LLVM::LLVMFuncOp
127 mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
128                                        Type unrankedDescriptorType) {
129   return LLVM::lookupOrCreateFn(
130       moduleOp, kMemRefCopy,
131       ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
132       LLVM::LLVMVoidType::get(moduleOp->getContext()));
133 }
134 
135 Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
136                                                    LLVM::LLVMFuncOp fn,
137                                                    ValueRange paramTypes,
138                                                    ArrayRef<Type> resultTypes) {
139   return b
140       .create<LLVM::CallOp>(loc, resultTypes, SymbolRefAttr::get(fn),
141                             paramTypes)
142       ->getResults();
143 }
144