1 //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "GPUOpsLowering.h"
10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/Builders.h"
13 #include "llvm/Support/FormatVariadic.h"
14 
15 using namespace mlir;
16 
17 LogicalResult
18 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
19                                    ConversionPatternRewriter &rewriter) const {
20   Location loc = gpuFuncOp.getLoc();
21 
22   SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
23   workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
24   for (auto en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
25     Value attribution = en.value();
26 
27     auto type = attribution.getType().dyn_cast<MemRefType>();
28     assert(type && type.hasStaticShape() && "unexpected type in attribution");
29 
30     uint64_t numElements = type.getNumElements();
31 
32     auto elementType =
33         typeConverter->convertType(type.getElementType()).template cast<Type>();
34     auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
35     std::string name = std::string(
36         llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
37     auto globalOp = rewriter.create<LLVM::GlobalOp>(
38         gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
39         LLVM::Linkage::Internal, name, /*value=*/Attribute(),
40         /*alignment=*/0, gpu::GPUDialect::getWorkgroupAddressSpace());
41     workgroupBuffers.push_back(globalOp);
42   }
43 
44   // Rewrite the original GPU function to an LLVM function.
45   auto funcType = typeConverter->convertType(gpuFuncOp.getType())
46                       .template cast<LLVM::LLVMPointerType>()
47                       .getElementType();
48 
49   // Remap proper input types.
50   TypeConverter::SignatureConversion signatureConversion(
51       gpuFuncOp.front().getNumArguments());
52   getTypeConverter()->convertFunctionSignature(
53       gpuFuncOp.getType(), /*isVariadic=*/false, signatureConversion);
54 
55   // Create the new function operation. Only copy those attributes that are
56   // not specific to function modeling.
57   SmallVector<NamedAttribute, 4> attributes;
58   for (const auto &attr : gpuFuncOp->getAttrs()) {
59     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
60         attr.getName() == function_like_impl::getTypeAttrName() ||
61         attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
62       continue;
63     attributes.push_back(attr);
64   }
65   // Add a dialect specific kernel attribute in addition to GPU kernel
66   // attribute. The former is necessary for further translation while the
67   // latter is expected by gpu.launch_func.
68   if (gpuFuncOp.isKernel())
69     attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
70   auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
71       gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
72       LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
73 
74   {
75     // Insert operations that correspond to converted workgroup and private
76     // memory attributions to the body of the function. This must operate on
77     // the original function, before the body region is inlined in the new
78     // function to maintain the relation between block arguments and the
79     // parent operation that assigns their semantics.
80     OpBuilder::InsertionGuard guard(rewriter);
81 
82     // Rewrite workgroup memory attributions to addresses of global buffers.
83     rewriter.setInsertionPointToStart(&gpuFuncOp.front());
84     unsigned numProperArguments = gpuFuncOp.getNumArguments();
85     auto i32Type = IntegerType::get(rewriter.getContext(), 32);
86 
87     Value zero = nullptr;
88     if (!workgroupBuffers.empty())
89       zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
90                                                rewriter.getI32IntegerAttr(0));
91     for (auto en : llvm::enumerate(workgroupBuffers)) {
92       LLVM::GlobalOp global = en.value();
93       Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
94       auto elementType =
95           global.getType().cast<LLVM::LLVMArrayType>().getElementType();
96       Value memory = rewriter.create<LLVM::GEPOp>(
97           loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()),
98           address, ArrayRef<Value>{zero, zero});
99 
100       // Build a memref descriptor pointing to the buffer to plug with the
101       // existing memref infrastructure. This may use more registers than
102       // otherwise necessary given that memref sizes are fixed, but we can try
103       // and canonicalize that away later.
104       Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
105       auto type = attribution.getType().cast<MemRefType>();
106       auto descr = MemRefDescriptor::fromStaticShape(
107           rewriter, loc, *getTypeConverter(), type, memory);
108       signatureConversion.remapInput(numProperArguments + en.index(), descr);
109     }
110 
111     // Rewrite private memory attributions to alloca'ed buffers.
112     unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
113     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
114     for (auto en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
115       Value attribution = en.value();
116       auto type = attribution.getType().cast<MemRefType>();
117       assert(type && type.hasStaticShape() && "unexpected type in attribution");
118 
119       // Explicitly drop memory space when lowering private memory
120       // attributions since NVVM models it as `alloca`s in the default
121       // memory space and does not support `alloca`s with addrspace(5).
122       auto ptrType = LLVM::LLVMPointerType::get(
123           typeConverter->convertType(type.getElementType())
124               .template cast<Type>(),
125           allocaAddrSpace);
126       Value numElements = rewriter.create<LLVM::ConstantOp>(
127           gpuFuncOp.getLoc(), int64Ty,
128           rewriter.getI64IntegerAttr(type.getNumElements()));
129       Value allocated = rewriter.create<LLVM::AllocaOp>(
130           gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
131       auto descr = MemRefDescriptor::fromStaticShape(
132           rewriter, loc, *getTypeConverter(), type, allocated);
133       signatureConversion.remapInput(
134           numProperArguments + numWorkgroupAttributions + en.index(), descr);
135     }
136   }
137 
138   // Move the region to the new function, update the entry block signature.
139   rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
140                               llvmFuncOp.end());
141   if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
142                                          &signatureConversion)))
143     return failure();
144 
145   rewriter.eraseOp(gpuFuncOp);
146   return success();
147 }
148 
/// Common prefix of the symbol names given to printf format-string globals;
/// a numeric suffix is appended to make each name unique within the module.
static constexpr char formatStringPrefix[] = "printfFormat_";
150 
151 template <typename T>
152 static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
153                                             ConversionPatternRewriter &rewriter,
154                                             StringRef name,
155                                             LLVM::LLVMFunctionType type) {
156   LLVM::LLVMFuncOp ret;
157   if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
158     ConversionPatternRewriter::InsertionGuard guard(rewriter);
159     rewriter.setInsertionPointToStart(moduleOp.getBody());
160     ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
161                                             LLVM::Linkage::External);
162   }
163   return ret;
164 }
165 
166 LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
167     gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
168     ConversionPatternRewriter &rewriter) const {
169   Location loc = gpuPrintfOp->getLoc();
170 
171   mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
172   mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
173   mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
174   mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
175   mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
176   // Note: this is the GPUModule op, not the ModuleOp that surrounds it
177   // This ensures that global constants and declarations are placed within
178   // the device code, not the host code
179   auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
180 
181   auto ocklBegin =
182       getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
183                           LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
184   LLVM::LLVMFuncOp ocklAppendArgs;
185   if (!adaptor.args().empty()) {
186     ocklAppendArgs = getOrDefineFunction(
187         moduleOp, loc, rewriter, "__ockl_printf_append_args",
188         LLVM::LLVMFunctionType::get(
189             llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
190                       llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
191   }
192   auto ocklAppendStringN = getOrDefineFunction(
193       moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
194       LLVM::LLVMFunctionType::get(
195           llvmI64,
196           {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
197 
198   /// Start the printf hostcall
199   Value zeroI64 = rewriter.create<LLVM::ConstantOp>(
200       loc, llvmI64, rewriter.getI64IntegerAttr(0));
201   auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
202   Value printfDesc = printfBeginCall.getResult(0);
203 
204   // Create a global constant for the format string
205   unsigned stringNumber = 0;
206   SmallString<16> stringConstName;
207   do {
208     stringConstName.clear();
209     (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
210   } while (moduleOp.lookupSymbol(stringConstName));
211 
212   llvm::SmallString<20> formatString(adaptor.format().getValue());
213   formatString.push_back('\0'); // Null terminate for C
214   size_t formatStringSize = formatString.size_in_bytes();
215 
216   auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
217   LLVM::GlobalOp global;
218   {
219     ConversionPatternRewriter::InsertionGuard guard(rewriter);
220     rewriter.setInsertionPointToStart(moduleOp.getBody());
221     global = rewriter.create<LLVM::GlobalOp>(
222         loc, globalType,
223         /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
224         rewriter.getStringAttr(formatString));
225   }
226 
227   // Get a pointer to the format string's first element and pass it to printf()
228   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
229   Value zero = rewriter.create<LLVM::ConstantOp>(
230       loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
231   Value stringStart = rewriter.create<LLVM::GEPOp>(
232       loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
233   Value stringLen = rewriter.create<LLVM::ConstantOp>(
234       loc, llvmI64, rewriter.getI64IntegerAttr(formatStringSize));
235 
236   Value oneI32 = rewriter.create<LLVM::ConstantOp>(
237       loc, llvmI32, rewriter.getI32IntegerAttr(1));
238   Value zeroI32 = rewriter.create<LLVM::ConstantOp>(
239       loc, llvmI32, rewriter.getI32IntegerAttr(0));
240 
241   mlir::ValueRange appendFormatArgs = {printfDesc, stringStart, stringLen,
242                                        adaptor.args().empty() ? oneI32
243                                                               : zeroI32};
244   auto appendFormatCall =
245       rewriter.create<LLVM::CallOp>(loc, ocklAppendStringN, appendFormatArgs);
246   printfDesc = appendFormatCall.getResult(0);
247 
248   // __ockl_printf_append_args takes 7 values per append call
249   constexpr size_t argsPerAppend = 7;
250   size_t nArgs = adaptor.args().size();
251   for (size_t group = 0; group < nArgs; group += argsPerAppend) {
252     size_t bound = std::min(group + argsPerAppend, nArgs);
253     size_t numArgsThisCall = bound - group;
254 
255     SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
256     arguments.push_back(printfDesc);
257     arguments.push_back(rewriter.create<LLVM::ConstantOp>(
258         loc, llvmI32, rewriter.getI32IntegerAttr(numArgsThisCall)));
259     for (size_t i = group; i < bound; ++i) {
260       Value arg = adaptor.args()[i];
261       if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
262         if (!floatType.isF64())
263           arg = rewriter.create<LLVM::FPExtOp>(
264               loc, typeConverter->convertType(rewriter.getF64Type()), arg);
265         arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
266       }
267       if (arg.getType().getIntOrFloatBitWidth() != 64)
268         arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
269 
270       arguments.push_back(arg);
271     }
272     // Pad out to 7 arguments since the hostcall always needs 7
273     for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
274       arguments.push_back(zeroI64);
275     }
276 
277     auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
278     arguments.push_back(isLast);
279     auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
280     printfDesc = call.getResult(0);
281   }
282   rewriter.eraseOp(gpuPrintfOp);
283   return success();
284 }
285 
286 LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
287     gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
288     ConversionPatternRewriter &rewriter) const {
289   Location loc = gpuPrintfOp->getLoc();
290 
291   mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
292   mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8, addressSpace);
293   mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
294 
295   // Note: this is the GPUModule op, not the ModuleOp that surrounds it
296   // This ensures that global constants and declarations are placed within
297   // the device code, not the host code
298   auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
299 
300   auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr},
301                                                 /*isVarArg=*/true);
302   LLVM::LLVMFuncOp printfDecl =
303       getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
304 
305   // Create a global constant for the format string
306   unsigned stringNumber = 0;
307   SmallString<16> stringConstName;
308   do {
309     stringConstName.clear();
310     (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
311   } while (moduleOp.lookupSymbol(stringConstName));
312 
313   llvm::SmallString<20> formatString(adaptor.format().getValue());
314   formatString.push_back('\0'); // Null terminate for C
315   auto globalType =
316       LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
317   LLVM::GlobalOp global;
318   {
319     ConversionPatternRewriter::InsertionGuard guard(rewriter);
320     rewriter.setInsertionPointToStart(moduleOp.getBody());
321     global = rewriter.create<LLVM::GlobalOp>(
322         loc, globalType,
323         /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
324         rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
325   }
326 
327   // Get a pointer to the format string's first element
328   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
329   Value zero = rewriter.create<LLVM::ConstantOp>(
330       loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
331   Value stringStart = rewriter.create<LLVM::GEPOp>(
332       loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
333 
334   // Construct arguments and function call
335   auto argsRange = adaptor.args();
336   SmallVector<Value, 4> printfArgs;
337   printfArgs.reserve(argsRange.size() + 1);
338   printfArgs.push_back(stringStart);
339   printfArgs.append(argsRange.begin(), argsRange.end());
340 
341   rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
342   rewriter.eraseOp(gpuPrintfOp);
343   return success();
344 }
345