1 //===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements helper functions to call common simple C functions in
10 // LLVMIR (e.g., among others, to support printing and debugging).
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/Support/LLVM.h"
19 
20 using namespace mlir;
21 using namespace mlir::LLVM;
22 
23 /// Helper functions to lookup or create the declaration for commonly used
24 /// external C function calls. The list of functions provided here must be
25 /// implemented separately (e.g. as  part of a support runtime library or as
26 /// part of the libc).
27 static constexpr llvm::StringRef kPrintI64 = "printI64";
28 static constexpr llvm::StringRef kPrintU64 = "printU64";
29 static constexpr llvm::StringRef kPrintF32 = "printF32";
30 static constexpr llvm::StringRef kPrintF64 = "printF64";
31 static constexpr llvm::StringRef kPrintOpen = "printOpen";
32 static constexpr llvm::StringRef kPrintClose = "printClose";
33 static constexpr llvm::StringRef kPrintComma = "printComma";
34 static constexpr llvm::StringRef kPrintNewline = "printNewline";
35 static constexpr llvm::StringRef kMalloc = "malloc";
36 static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
37 static constexpr llvm::StringRef kFree = "free";
38 static constexpr llvm::StringRef kGenericAlloc = "_mlir_alloc";
39 static constexpr llvm::StringRef kGenericAlignedAlloc = "_mlir_aligned_alloc";
40 static constexpr llvm::StringRef kGenericFree = "_mlir_free";
41 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
42 
43 /// Generic print function lookupOrCreate helper.
lookupOrCreateFn(ModuleOp moduleOp,StringRef name,ArrayRef<Type> paramTypes,Type resultType)44 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
45                                               ArrayRef<Type> paramTypes,
46                                               Type resultType) {
47   auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
48   if (func)
49     return func;
50   OpBuilder b(moduleOp.getBodyRegion());
51   return b.create<LLVM::LLVMFuncOp>(
52       moduleOp->getLoc(), name,
53       LLVM::LLVMFunctionType::get(resultType, paramTypes));
54 }
55 
lookupOrCreatePrintI64Fn(ModuleOp moduleOp)56 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
57   return lookupOrCreateFn(moduleOp, kPrintI64,
58                           IntegerType::get(moduleOp->getContext(), 64),
59                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
60 }
61 
lookupOrCreatePrintU64Fn(ModuleOp moduleOp)62 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
63   return lookupOrCreateFn(moduleOp, kPrintU64,
64                           IntegerType::get(moduleOp->getContext(), 64),
65                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
66 }
67 
lookupOrCreatePrintF32Fn(ModuleOp moduleOp)68 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
69   return lookupOrCreateFn(moduleOp, kPrintF32,
70                           Float32Type::get(moduleOp->getContext()),
71                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
72 }
73 
lookupOrCreatePrintF64Fn(ModuleOp moduleOp)74 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
75   return lookupOrCreateFn(moduleOp, kPrintF64,
76                           Float64Type::get(moduleOp->getContext()),
77                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
78 }
79 
lookupOrCreatePrintOpenFn(ModuleOp moduleOp)80 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
81   return lookupOrCreateFn(moduleOp, kPrintOpen, {},
82                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
83 }
84 
lookupOrCreatePrintCloseFn(ModuleOp moduleOp)85 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
86   return lookupOrCreateFn(moduleOp, kPrintClose, {},
87                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
88 }
89 
lookupOrCreatePrintCommaFn(ModuleOp moduleOp)90 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
91   return lookupOrCreateFn(moduleOp, kPrintComma, {},
92                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
93 }
94 
lookupOrCreatePrintNewlineFn(ModuleOp moduleOp)95 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
96   return lookupOrCreateFn(moduleOp, kPrintNewline, {},
97                           LLVM::LLVMVoidType::get(moduleOp->getContext()));
98 }
99 
lookupOrCreateMallocFn(ModuleOp moduleOp,Type indexType)100 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
101                                                     Type indexType) {
102   return LLVM::lookupOrCreateFn(
103       moduleOp, kMalloc, indexType,
104       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
105 }
106 
lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,Type indexType)107 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
108                                                           Type indexType) {
109   return LLVM::lookupOrCreateFn(
110       moduleOp, kAlignedAlloc, {indexType, indexType},
111       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
112 }
113 
lookupOrCreateFreeFn(ModuleOp moduleOp)114 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
115   return LLVM::lookupOrCreateFn(
116       moduleOp, kFree,
117       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)),
118       LLVM::LLVMVoidType::get(moduleOp->getContext()));
119 }
120 
lookupOrCreateGenericAllocFn(ModuleOp moduleOp,Type indexType)121 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
122                                                           Type indexType) {
123   return LLVM::lookupOrCreateFn(
124       moduleOp, kGenericAlloc, indexType,
125       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
126 }
127 
128 LLVM::LLVMFuncOp
lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,Type indexType)129 mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
130                                                 Type indexType) {
131   return LLVM::lookupOrCreateFn(
132       moduleOp, kGenericAlignedAlloc, {indexType, indexType},
133       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
134 }
135 
lookupOrCreateGenericFreeFn(ModuleOp moduleOp)136 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) {
137   return LLVM::lookupOrCreateFn(
138       moduleOp, kGenericFree,
139       LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)),
140       LLVM::LLVMVoidType::get(moduleOp->getContext()));
141 }
142 
143 LLVM::LLVMFuncOp
lookupOrCreateMemRefCopyFn(ModuleOp moduleOp,Type indexType,Type unrankedDescriptorType)144 mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
145                                        Type unrankedDescriptorType) {
146   return LLVM::lookupOrCreateFn(
147       moduleOp, kMemRefCopy,
148       ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
149       LLVM::LLVMVoidType::get(moduleOp->getContext()));
150 }
151 
createLLVMCall(OpBuilder & b,Location loc,LLVM::LLVMFuncOp fn,ValueRange paramTypes,ArrayRef<Type> resultTypes)152 Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
153                                                    LLVM::LLVMFuncOp fn,
154                                                    ValueRange paramTypes,
155                                                    ArrayRef<Type> resultTypes) {
156   return b
157       .create<LLVM::CallOp>(loc, resultTypes, SymbolRefAttr::get(fn),
158                             paramTypes)
159       ->getResults();
160 }
161