1 //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "GPUOpsLowering.h"
10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11 #include "mlir/IR/Builders.h"
12 #include "llvm/Support/FormatVariadic.h"
13
14 using namespace mlir;
15
16 LogicalResult
matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const17 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
18 ConversionPatternRewriter &rewriter) const {
19 Location loc = gpuFuncOp.getLoc();
20
21 SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
22 workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
23 for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
24 Value attribution = en.value();
25
26 auto type = attribution.getType().dyn_cast<MemRefType>();
27 assert(type && type.hasStaticShape() && "unexpected type in attribution");
28
29 uint64_t numElements = type.getNumElements();
30
31 auto elementType =
32 typeConverter->convertType(type.getElementType()).template cast<Type>();
33 auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
34 std::string name = std::string(
35 llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
36 auto globalOp = rewriter.create<LLVM::GlobalOp>(
37 gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
38 LLVM::Linkage::Internal, name, /*value=*/Attribute(),
39 /*alignment=*/0, gpu::GPUDialect::getWorkgroupAddressSpace());
40 workgroupBuffers.push_back(globalOp);
41 }
42
43 // Rewrite the original GPU function to an LLVM function.
44 auto funcType = typeConverter->convertType(gpuFuncOp.getFunctionType())
45 .template cast<LLVM::LLVMPointerType>()
46 .getElementType();
47
48 // Remap proper input types.
49 TypeConverter::SignatureConversion signatureConversion(
50 gpuFuncOp.front().getNumArguments());
51 getTypeConverter()->convertFunctionSignature(
52 gpuFuncOp.getFunctionType(), /*isVariadic=*/false, signatureConversion);
53
54 // Create the new function operation. Only copy those attributes that are
55 // not specific to function modeling.
56 SmallVector<NamedAttribute, 4> attributes;
57 for (const auto &attr : gpuFuncOp->getAttrs()) {
58 if (attr.getName() == SymbolTable::getSymbolAttrName() ||
59 attr.getName() == FunctionOpInterface::getTypeAttrName() ||
60 attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
61 continue;
62 attributes.push_back(attr);
63 }
64 // Add a dialect specific kernel attribute in addition to GPU kernel
65 // attribute. The former is necessary for further translation while the
66 // latter is expected by gpu.launch_func.
67 if (gpuFuncOp.isKernel())
68 attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
69 auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
70 gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
71 LLVM::Linkage::External, /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C,
72 attributes);
73
74 {
75 // Insert operations that correspond to converted workgroup and private
76 // memory attributions to the body of the function. This must operate on
77 // the original function, before the body region is inlined in the new
78 // function to maintain the relation between block arguments and the
79 // parent operation that assigns their semantics.
80 OpBuilder::InsertionGuard guard(rewriter);
81
82 // Rewrite workgroup memory attributions to addresses of global buffers.
83 rewriter.setInsertionPointToStart(&gpuFuncOp.front());
84 unsigned numProperArguments = gpuFuncOp.getNumArguments();
85 auto i32Type = IntegerType::get(rewriter.getContext(), 32);
86
87 Value zero = nullptr;
88 if (!workgroupBuffers.empty())
89 zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
90 rewriter.getI32IntegerAttr(0));
91 for (const auto &en : llvm::enumerate(workgroupBuffers)) {
92 LLVM::GlobalOp global = en.value();
93 Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
94 auto elementType =
95 global.getType().cast<LLVM::LLVMArrayType>().getElementType();
96 Value memory = rewriter.create<LLVM::GEPOp>(
97 loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()),
98 address, ArrayRef<Value>{zero, zero});
99
100 // Build a memref descriptor pointing to the buffer to plug with the
101 // existing memref infrastructure. This may use more registers than
102 // otherwise necessary given that memref sizes are fixed, but we can try
103 // and canonicalize that away later.
104 Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
105 auto type = attribution.getType().cast<MemRefType>();
106 auto descr = MemRefDescriptor::fromStaticShape(
107 rewriter, loc, *getTypeConverter(), type, memory);
108 signatureConversion.remapInput(numProperArguments + en.index(), descr);
109 }
110
111 // Rewrite private memory attributions to alloca'ed buffers.
112 unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
113 auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
114 for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
115 Value attribution = en.value();
116 auto type = attribution.getType().cast<MemRefType>();
117 assert(type && type.hasStaticShape() && "unexpected type in attribution");
118
119 // Explicitly drop memory space when lowering private memory
120 // attributions since NVVM models it as `alloca`s in the default
121 // memory space and does not support `alloca`s with addrspace(5).
122 auto ptrType = LLVM::LLVMPointerType::get(
123 typeConverter->convertType(type.getElementType())
124 .template cast<Type>(),
125 allocaAddrSpace);
126 Value numElements = rewriter.create<LLVM::ConstantOp>(
127 gpuFuncOp.getLoc(), int64Ty,
128 rewriter.getI64IntegerAttr(type.getNumElements()));
129 Value allocated = rewriter.create<LLVM::AllocaOp>(
130 gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
131 auto descr = MemRefDescriptor::fromStaticShape(
132 rewriter, loc, *getTypeConverter(), type, allocated);
133 signatureConversion.remapInput(
134 numProperArguments + numWorkgroupAttributions + en.index(), descr);
135 }
136 }
137
138 // Move the region to the new function, update the entry block signature.
139 rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
140 llvmFuncOp.end());
141 if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
142 &signatureConversion)))
143 return failure();
144
145 rewriter.eraseOp(gpuFuncOp);
146 return success();
147 }
148
/// Prefix of the symbol names given to the global constants that hold printf
/// format strings; a decimal counter is appended to make each name unique.
static constexpr char formatStringPrefix[] = "printfFormat_";
150
151 template <typename T>
getOrDefineFunction(T & moduleOp,const Location loc,ConversionPatternRewriter & rewriter,StringRef name,LLVM::LLVMFunctionType type)152 static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
153 ConversionPatternRewriter &rewriter,
154 StringRef name,
155 LLVM::LLVMFunctionType type) {
156 LLVM::LLVMFuncOp ret;
157 if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
158 ConversionPatternRewriter::InsertionGuard guard(rewriter);
159 rewriter.setInsertionPointToStart(moduleOp.getBody());
160 ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
161 LLVM::Linkage::External);
162 }
163 return ret;
164 }
165
matchAndRewrite(gpu::PrintfOp gpuPrintfOp,gpu::PrintfOpAdaptor adaptor,ConversionPatternRewriter & rewriter) const166 LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
167 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
168 ConversionPatternRewriter &rewriter) const {
169 Location loc = gpuPrintfOp->getLoc();
170
171 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
172 mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
173 mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
174 mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
175 mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
176 // Note: this is the GPUModule op, not the ModuleOp that surrounds it
177 // This ensures that global constants and declarations are placed within
178 // the device code, not the host code
179 auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
180
181 auto ocklBegin =
182 getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
183 LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
184 LLVM::LLVMFuncOp ocklAppendArgs;
185 if (!adaptor.args().empty()) {
186 ocklAppendArgs = getOrDefineFunction(
187 moduleOp, loc, rewriter, "__ockl_printf_append_args",
188 LLVM::LLVMFunctionType::get(
189 llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
190 llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
191 }
192 auto ocklAppendStringN = getOrDefineFunction(
193 moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
194 LLVM::LLVMFunctionType::get(
195 llvmI64,
196 {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
197
198 /// Start the printf hostcall
199 Value zeroI64 = rewriter.create<LLVM::ConstantOp>(
200 loc, llvmI64, rewriter.getI64IntegerAttr(0));
201 auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
202 Value printfDesc = printfBeginCall.getResult(0);
203
204 // Create a global constant for the format string
205 unsigned stringNumber = 0;
206 SmallString<16> stringConstName;
207 do {
208 stringConstName.clear();
209 (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
210 } while (moduleOp.lookupSymbol(stringConstName));
211
212 llvm::SmallString<20> formatString(adaptor.format());
213 formatString.push_back('\0'); // Null terminate for C
214 size_t formatStringSize = formatString.size_in_bytes();
215
216 auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
217 LLVM::GlobalOp global;
218 {
219 ConversionPatternRewriter::InsertionGuard guard(rewriter);
220 rewriter.setInsertionPointToStart(moduleOp.getBody());
221 global = rewriter.create<LLVM::GlobalOp>(
222 loc, globalType,
223 /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
224 rewriter.getStringAttr(formatString));
225 }
226
227 // Get a pointer to the format string's first element and pass it to printf()
228 Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
229 Value zero = rewriter.create<LLVM::ConstantOp>(
230 loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
231 Value stringStart = rewriter.create<LLVM::GEPOp>(
232 loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
233 Value stringLen = rewriter.create<LLVM::ConstantOp>(
234 loc, llvmI64, rewriter.getI64IntegerAttr(formatStringSize));
235
236 Value oneI32 = rewriter.create<LLVM::ConstantOp>(
237 loc, llvmI32, rewriter.getI32IntegerAttr(1));
238 Value zeroI32 = rewriter.create<LLVM::ConstantOp>(
239 loc, llvmI32, rewriter.getI32IntegerAttr(0));
240
241 auto appendFormatCall = rewriter.create<LLVM::CallOp>(
242 loc, ocklAppendStringN,
243 ValueRange{printfDesc, stringStart, stringLen,
244 adaptor.args().empty() ? oneI32 : zeroI32});
245 printfDesc = appendFormatCall.getResult(0);
246
247 // __ockl_printf_append_args takes 7 values per append call
248 constexpr size_t argsPerAppend = 7;
249 size_t nArgs = adaptor.args().size();
250 for (size_t group = 0; group < nArgs; group += argsPerAppend) {
251 size_t bound = std::min(group + argsPerAppend, nArgs);
252 size_t numArgsThisCall = bound - group;
253
254 SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
255 arguments.push_back(printfDesc);
256 arguments.push_back(rewriter.create<LLVM::ConstantOp>(
257 loc, llvmI32, rewriter.getI32IntegerAttr(numArgsThisCall)));
258 for (size_t i = group; i < bound; ++i) {
259 Value arg = adaptor.args()[i];
260 if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
261 if (!floatType.isF64())
262 arg = rewriter.create<LLVM::FPExtOp>(
263 loc, typeConverter->convertType(rewriter.getF64Type()), arg);
264 arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
265 }
266 if (arg.getType().getIntOrFloatBitWidth() != 64)
267 arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
268
269 arguments.push_back(arg);
270 }
271 // Pad out to 7 arguments since the hostcall always needs 7
272 for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
273 arguments.push_back(zeroI64);
274 }
275
276 auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
277 arguments.push_back(isLast);
278 auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
279 printfDesc = call.getResult(0);
280 }
281 rewriter.eraseOp(gpuPrintfOp);
282 return success();
283 }
284
matchAndRewrite(gpu::PrintfOp gpuPrintfOp,gpu::PrintfOpAdaptor adaptor,ConversionPatternRewriter & rewriter) const285 LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
286 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
287 ConversionPatternRewriter &rewriter) const {
288 Location loc = gpuPrintfOp->getLoc();
289
290 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
291 mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8, addressSpace);
292 mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
293
294 // Note: this is the GPUModule op, not the ModuleOp that surrounds it
295 // This ensures that global constants and declarations are placed within
296 // the device code, not the host code
297 auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
298
299 auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr},
300 /*isVarArg=*/true);
301 LLVM::LLVMFuncOp printfDecl =
302 getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
303
304 // Create a global constant for the format string
305 unsigned stringNumber = 0;
306 SmallString<16> stringConstName;
307 do {
308 stringConstName.clear();
309 (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
310 } while (moduleOp.lookupSymbol(stringConstName));
311
312 llvm::SmallString<20> formatString(adaptor.format());
313 formatString.push_back('\0'); // Null terminate for C
314 auto globalType =
315 LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
316 LLVM::GlobalOp global;
317 {
318 ConversionPatternRewriter::InsertionGuard guard(rewriter);
319 rewriter.setInsertionPointToStart(moduleOp.getBody());
320 global = rewriter.create<LLVM::GlobalOp>(
321 loc, globalType,
322 /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
323 rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
324 }
325
326 // Get a pointer to the format string's first element
327 Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
328 Value zero = rewriter.create<LLVM::ConstantOp>(
329 loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
330 Value stringStart = rewriter.create<LLVM::GEPOp>(
331 loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
332
333 // Construct arguments and function call
334 auto argsRange = adaptor.args();
335 SmallVector<Value, 4> printfArgs;
336 printfArgs.reserve(argsRange.size() + 1);
337 printfArgs.push_back(stringStart);
338 printfArgs.append(argsRange.begin(), argsRange.end());
339
340 rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
341 rewriter.eraseOp(gpuPrintfOp);
342 return success();
343 }
344