1 //===- FunctionCallUtils.cpp - Utilities for C function calls -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements helper functions for calling common, simple C functions
// from the LLVMIR dialect (e.g., among others, to support printing and
// debugging).
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
15 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/OpDefinition.h"
18 #include "mlir/Support/LLVM.h"
19
20 using namespace mlir;
21 using namespace mlir::LLVM;
22
23 /// Helper functions to lookup or create the declaration for commonly used
24 /// external C function calls. The list of functions provided here must be
25 /// implemented separately (e.g. as part of a support runtime library or as
26 /// part of the libc).
27 static constexpr llvm::StringRef kPrintI64 = "printI64";
28 static constexpr llvm::StringRef kPrintU64 = "printU64";
29 static constexpr llvm::StringRef kPrintF32 = "printF32";
30 static constexpr llvm::StringRef kPrintF64 = "printF64";
31 static constexpr llvm::StringRef kPrintOpen = "printOpen";
32 static constexpr llvm::StringRef kPrintClose = "printClose";
33 static constexpr llvm::StringRef kPrintComma = "printComma";
34 static constexpr llvm::StringRef kPrintNewline = "printNewline";
35 static constexpr llvm::StringRef kMalloc = "malloc";
36 static constexpr llvm::StringRef kAlignedAlloc = "aligned_alloc";
37 static constexpr llvm::StringRef kFree = "free";
38 static constexpr llvm::StringRef kGenericAlloc = "_mlir_alloc";
39 static constexpr llvm::StringRef kGenericAlignedAlloc = "_mlir_aligned_alloc";
40 static constexpr llvm::StringRef kGenericFree = "_mlir_free";
41 static constexpr llvm::StringRef kMemRefCopy = "memrefCopy";
42
43 /// Generic print function lookupOrCreate helper.
lookupOrCreateFn(ModuleOp moduleOp,StringRef name,ArrayRef<Type> paramTypes,Type resultType)44 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFn(ModuleOp moduleOp, StringRef name,
45 ArrayRef<Type> paramTypes,
46 Type resultType) {
47 auto func = moduleOp.lookupSymbol<LLVM::LLVMFuncOp>(name);
48 if (func)
49 return func;
50 OpBuilder b(moduleOp.getBodyRegion());
51 return b.create<LLVM::LLVMFuncOp>(
52 moduleOp->getLoc(), name,
53 LLVM::LLVMFunctionType::get(resultType, paramTypes));
54 }
55
lookupOrCreatePrintI64Fn(ModuleOp moduleOp)56 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintI64Fn(ModuleOp moduleOp) {
57 return lookupOrCreateFn(moduleOp, kPrintI64,
58 IntegerType::get(moduleOp->getContext(), 64),
59 LLVM::LLVMVoidType::get(moduleOp->getContext()));
60 }
61
lookupOrCreatePrintU64Fn(ModuleOp moduleOp)62 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintU64Fn(ModuleOp moduleOp) {
63 return lookupOrCreateFn(moduleOp, kPrintU64,
64 IntegerType::get(moduleOp->getContext(), 64),
65 LLVM::LLVMVoidType::get(moduleOp->getContext()));
66 }
67
lookupOrCreatePrintF32Fn(ModuleOp moduleOp)68 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF32Fn(ModuleOp moduleOp) {
69 return lookupOrCreateFn(moduleOp, kPrintF32,
70 Float32Type::get(moduleOp->getContext()),
71 LLVM::LLVMVoidType::get(moduleOp->getContext()));
72 }
73
lookupOrCreatePrintF64Fn(ModuleOp moduleOp)74 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintF64Fn(ModuleOp moduleOp) {
75 return lookupOrCreateFn(moduleOp, kPrintF64,
76 Float64Type::get(moduleOp->getContext()),
77 LLVM::LLVMVoidType::get(moduleOp->getContext()));
78 }
79
lookupOrCreatePrintOpenFn(ModuleOp moduleOp)80 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintOpenFn(ModuleOp moduleOp) {
81 return lookupOrCreateFn(moduleOp, kPrintOpen, {},
82 LLVM::LLVMVoidType::get(moduleOp->getContext()));
83 }
84
lookupOrCreatePrintCloseFn(ModuleOp moduleOp)85 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCloseFn(ModuleOp moduleOp) {
86 return lookupOrCreateFn(moduleOp, kPrintClose, {},
87 LLVM::LLVMVoidType::get(moduleOp->getContext()));
88 }
89
lookupOrCreatePrintCommaFn(ModuleOp moduleOp)90 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintCommaFn(ModuleOp moduleOp) {
91 return lookupOrCreateFn(moduleOp, kPrintComma, {},
92 LLVM::LLVMVoidType::get(moduleOp->getContext()));
93 }
94
lookupOrCreatePrintNewlineFn(ModuleOp moduleOp)95 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreatePrintNewlineFn(ModuleOp moduleOp) {
96 return lookupOrCreateFn(moduleOp, kPrintNewline, {},
97 LLVM::LLVMVoidType::get(moduleOp->getContext()));
98 }
99
lookupOrCreateMallocFn(ModuleOp moduleOp,Type indexType)100 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateMallocFn(ModuleOp moduleOp,
101 Type indexType) {
102 return LLVM::lookupOrCreateFn(
103 moduleOp, kMalloc, indexType,
104 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
105 }
106
lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,Type indexType)107 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateAlignedAllocFn(ModuleOp moduleOp,
108 Type indexType) {
109 return LLVM::lookupOrCreateFn(
110 moduleOp, kAlignedAlloc, {indexType, indexType},
111 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
112 }
113
lookupOrCreateFreeFn(ModuleOp moduleOp)114 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateFreeFn(ModuleOp moduleOp) {
115 return LLVM::lookupOrCreateFn(
116 moduleOp, kFree,
117 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)),
118 LLVM::LLVMVoidType::get(moduleOp->getContext()));
119 }
120
lookupOrCreateGenericAllocFn(ModuleOp moduleOp,Type indexType)121 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericAllocFn(ModuleOp moduleOp,
122 Type indexType) {
123 return LLVM::lookupOrCreateFn(
124 moduleOp, kGenericAlloc, indexType,
125 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
126 }
127
128 LLVM::LLVMFuncOp
lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,Type indexType)129 mlir::LLVM::lookupOrCreateGenericAlignedAllocFn(ModuleOp moduleOp,
130 Type indexType) {
131 return LLVM::lookupOrCreateFn(
132 moduleOp, kGenericAlignedAlloc, {indexType, indexType},
133 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)));
134 }
135
lookupOrCreateGenericFreeFn(ModuleOp moduleOp)136 LLVM::LLVMFuncOp mlir::LLVM::lookupOrCreateGenericFreeFn(ModuleOp moduleOp) {
137 return LLVM::lookupOrCreateFn(
138 moduleOp, kGenericFree,
139 LLVM::LLVMPointerType::get(IntegerType::get(moduleOp->getContext(), 8)),
140 LLVM::LLVMVoidType::get(moduleOp->getContext()));
141 }
142
143 LLVM::LLVMFuncOp
lookupOrCreateMemRefCopyFn(ModuleOp moduleOp,Type indexType,Type unrankedDescriptorType)144 mlir::LLVM::lookupOrCreateMemRefCopyFn(ModuleOp moduleOp, Type indexType,
145 Type unrankedDescriptorType) {
146 return LLVM::lookupOrCreateFn(
147 moduleOp, kMemRefCopy,
148 ArrayRef<Type>{indexType, unrankedDescriptorType, unrankedDescriptorType},
149 LLVM::LLVMVoidType::get(moduleOp->getContext()));
150 }
151
createLLVMCall(OpBuilder & b,Location loc,LLVM::LLVMFuncOp fn,ValueRange paramTypes,ArrayRef<Type> resultTypes)152 Operation::result_range mlir::LLVM::createLLVMCall(OpBuilder &b, Location loc,
153 LLVM::LLVMFuncOp fn,
154 ValueRange paramTypes,
155 ArrayRef<Type> resultTypes) {
156 return b
157 .create<LLVM::CallOp>(loc, resultTypes, SymbolRefAttr::get(fn),
158 paramTypes)
159 ->getResults();
160 }
161