1 //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "GPUOpsLowering.h"
10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
11 #include "mlir/IR/Builders.h"
12 #include "llvm/Support/FormatVariadic.h"
13 
14 using namespace mlir;
15 
16 LogicalResult
17 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
18                                    ConversionPatternRewriter &rewriter) const {
19   Location loc = gpuFuncOp.getLoc();
20 
21   SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
22   workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
23   for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
24     Value attribution = en.value();
25 
26     auto type = attribution.getType().dyn_cast<MemRefType>();
27     assert(type && type.hasStaticShape() && "unexpected type in attribution");
28 
29     uint64_t numElements = type.getNumElements();
30 
31     auto elementType =
32         typeConverter->convertType(type.getElementType()).template cast<Type>();
33     auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
34     std::string name = std::string(
35         llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
36     auto globalOp = rewriter.create<LLVM::GlobalOp>(
37         gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
38         LLVM::Linkage::Internal, name, /*value=*/Attribute(),
39         /*alignment=*/0, gpu::GPUDialect::getWorkgroupAddressSpace());
40     workgroupBuffers.push_back(globalOp);
41   }
42 
43   // Rewrite the original GPU function to an LLVM function.
44   auto funcType = typeConverter->convertType(gpuFuncOp.getFunctionType())
45                       .template cast<LLVM::LLVMPointerType>()
46                       .getElementType();
47 
48   // Remap proper input types.
49   TypeConverter::SignatureConversion signatureConversion(
50       gpuFuncOp.front().getNumArguments());
51   getTypeConverter()->convertFunctionSignature(
52       gpuFuncOp.getFunctionType(), /*isVariadic=*/false, signatureConversion);
53 
54   // Create the new function operation. Only copy those attributes that are
55   // not specific to function modeling.
56   SmallVector<NamedAttribute, 4> attributes;
57   for (const auto &attr : gpuFuncOp->getAttrs()) {
58     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
59         attr.getName() == FunctionOpInterface::getTypeAttrName() ||
60         attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
61       continue;
62     attributes.push_back(attr);
63   }
64   // Add a dialect specific kernel attribute in addition to GPU kernel
65   // attribute. The former is necessary for further translation while the
66   // latter is expected by gpu.launch_func.
67   if (gpuFuncOp.isKernel())
68     attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
69   auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
70       gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
71       LLVM::Linkage::External, /*dsoLocal*/ false, attributes);
72 
73   {
74     // Insert operations that correspond to converted workgroup and private
75     // memory attributions to the body of the function. This must operate on
76     // the original function, before the body region is inlined in the new
77     // function to maintain the relation between block arguments and the
78     // parent operation that assigns their semantics.
79     OpBuilder::InsertionGuard guard(rewriter);
80 
81     // Rewrite workgroup memory attributions to addresses of global buffers.
82     rewriter.setInsertionPointToStart(&gpuFuncOp.front());
83     unsigned numProperArguments = gpuFuncOp.getNumArguments();
84     auto i32Type = IntegerType::get(rewriter.getContext(), 32);
85 
86     Value zero = nullptr;
87     if (!workgroupBuffers.empty())
88       zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
89                                                rewriter.getI32IntegerAttr(0));
90     for (const auto &en : llvm::enumerate(workgroupBuffers)) {
91       LLVM::GlobalOp global = en.value();
92       Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
93       auto elementType =
94           global.getType().cast<LLVM::LLVMArrayType>().getElementType();
95       Value memory = rewriter.create<LLVM::GEPOp>(
96           loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()),
97           address, ArrayRef<Value>{zero, zero});
98 
99       // Build a memref descriptor pointing to the buffer to plug with the
100       // existing memref infrastructure. This may use more registers than
101       // otherwise necessary given that memref sizes are fixed, but we can try
102       // and canonicalize that away later.
103       Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
104       auto type = attribution.getType().cast<MemRefType>();
105       auto descr = MemRefDescriptor::fromStaticShape(
106           rewriter, loc, *getTypeConverter(), type, memory);
107       signatureConversion.remapInput(numProperArguments + en.index(), descr);
108     }
109 
110     // Rewrite private memory attributions to alloca'ed buffers.
111     unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
112     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
113     for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
114       Value attribution = en.value();
115       auto type = attribution.getType().cast<MemRefType>();
116       assert(type && type.hasStaticShape() && "unexpected type in attribution");
117 
118       // Explicitly drop memory space when lowering private memory
119       // attributions since NVVM models it as `alloca`s in the default
120       // memory space and does not support `alloca`s with addrspace(5).
121       auto ptrType = LLVM::LLVMPointerType::get(
122           typeConverter->convertType(type.getElementType())
123               .template cast<Type>(),
124           allocaAddrSpace);
125       Value numElements = rewriter.create<LLVM::ConstantOp>(
126           gpuFuncOp.getLoc(), int64Ty,
127           rewriter.getI64IntegerAttr(type.getNumElements()));
128       Value allocated = rewriter.create<LLVM::AllocaOp>(
129           gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
130       auto descr = MemRefDescriptor::fromStaticShape(
131           rewriter, loc, *getTypeConverter(), type, allocated);
132       signatureConversion.remapInput(
133           numProperArguments + numWorkgroupAttributions + en.index(), descr);
134     }
135   }
136 
137   // Move the region to the new function, update the entry block signature.
138   rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
139                               llvmFuncOp.end());
140   if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
141                                          &signatureConversion)))
142     return failure();
143 
144   rewriter.eraseOp(gpuFuncOp);
145   return success();
146 }
147 
/// Symbol-name prefix for printf format-string globals. A numeric suffix is
/// appended and incremented until the name is unused in the gpu.module.
static constexpr char formatStringPrefix[] = "printfFormat_";
149 
150 template <typename T>
151 static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
152                                             ConversionPatternRewriter &rewriter,
153                                             StringRef name,
154                                             LLVM::LLVMFunctionType type) {
155   LLVM::LLVMFuncOp ret;
156   if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
157     ConversionPatternRewriter::InsertionGuard guard(rewriter);
158     rewriter.setInsertionPointToStart(moduleOp.getBody());
159     ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
160                                             LLVM::Linkage::External);
161   }
162   return ret;
163 }
164 
/// Lowers gpu.printf into a chain of calls to the `__ockl_printf_*` functions.
/// The functions are declared on demand in the enclosing gpu.module:
///   1. `__ockl_printf_begin(i64 0)` returns an i64 printf descriptor.
///   2. `__ockl_printf_append_string_n(desc, str, len, isLast)` appends the
///      format string. The string is stored NUL-terminated in an internal
///      constant global `printfFormat_<N>` in the gpu.module.
///   3. `__ockl_printf_append_args(desc, numArgs, 7 x i64, isLast)` appends
///      the operands in groups of at most 7, each widened to i64. Unused
///      slots are padded with 0, and the final group sets `isLast` to 1.
/// Each call returns an updated descriptor that is threaded into the next.
LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
    gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = gpuPrintfOp->getLoc();

  // The i8 pointer is created without an explicit address space argument.
  mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
  mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
  mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
  mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
  mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
  // Note: this is the GPUModule op, not the ModuleOp that surrounds it.
  // This ensures that global constants and declarations are placed within
  // the device code, not the host code.
  auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();

  auto ocklBegin =
      getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
                          LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
  // The append_args declaration is only needed (and only emitted) when the
  // printf has operands; it is used exclusively inside the loop below.
  LLVM::LLVMFuncOp ocklAppendArgs;
  if (!adaptor.args().empty()) {
    ocklAppendArgs = getOrDefineFunction(
        moduleOp, loc, rewriter, "__ockl_printf_append_args",
        LLVM::LLVMFunctionType::get(
            llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
                      llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
  }
  auto ocklAppendStringN = getOrDefineFunction(
      moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
      LLVM::LLVMFunctionType::get(
          llvmI64,
          {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));

  /// Start the printf hostcall
  // zeroI64 is also reused below to pad unused append_args slots.
  Value zeroI64 = rewriter.create<LLVM::ConstantOp>(
      loc, llvmI64, rewriter.getI64IntegerAttr(0));
  auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
  Value printfDesc = printfBeginCall.getResult(0);

  // Create a global constant for the format string. Pick the first
  // `printfFormat_<N>` name that is not yet defined in the gpu.module.
  unsigned stringNumber = 0;
  SmallString<16> stringConstName;
  do {
    stringConstName.clear();
    (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
  } while (moduleOp.lookupSymbol(stringConstName));

  llvm::SmallString<20> formatString(adaptor.format());
  formatString.push_back('\0'); // Null terminate for C
  // Measured after appending the terminator, so the length passed to
  // __ockl_printf_append_string_n below includes the trailing NUL.
  size_t formatStringSize = formatString.size_in_bytes();

  auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
  LLVM::GlobalOp global;
  {
    ConversionPatternRewriter::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(moduleOp.getBody());
    global = rewriter.create<LLVM::GlobalOp>(
        loc, globalType,
        /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
        rewriter.getStringAttr(formatString));
  }

  // Get a pointer to the format string's first element (GEP [0, 0]) and its
  // byte length; both are passed to __ockl_printf_append_string_n.
  Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
  Value zero = rewriter.create<LLVM::ConstantOp>(
      loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
  Value stringStart = rewriter.create<LLVM::GEPOp>(
      loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
  Value stringLen = rewriter.create<LLVM::ConstantOp>(
      loc, llvmI64, rewriter.getI64IntegerAttr(formatStringSize));

  // i32 values used for the `isLast` flag (1 = last append, 0 = more follow).
  Value oneI32 = rewriter.create<LLVM::ConstantOp>(
      loc, llvmI32, rewriter.getI32IntegerAttr(1));
  Value zeroI32 = rewriter.create<LLVM::ConstantOp>(
      loc, llvmI32, rewriter.getI32IntegerAttr(0));

  // The format string is the last append only when there are no operands.
  auto appendFormatCall = rewriter.create<LLVM::CallOp>(
      loc, ocklAppendStringN,
      ValueRange{printfDesc, stringStart, stringLen,
                 adaptor.args().empty() ? oneI32 : zeroI32});
  printfDesc = appendFormatCall.getResult(0);

  // __ockl_printf_append_args takes 7 values per append call
  constexpr size_t argsPerAppend = 7;
  size_t nArgs = adaptor.args().size();
  for (size_t group = 0; group < nArgs; group += argsPerAppend) {
    // Operands [group, bound) go into this call.
    size_t bound = std::min(group + argsPerAppend, nArgs);
    size_t numArgsThisCall = bound - group;

    // Layout: descriptor, numArgs, 7 value slots, isLast.
    SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
    arguments.push_back(printfDesc);
    arguments.push_back(rewriter.create<LLVM::ConstantOp>(
        loc, llvmI32, rewriter.getI32IntegerAttr(numArgsThisCall)));
    for (size_t i = group; i < bound; ++i) {
      Value arg = adaptor.args()[i];
      // Floats are extended to f64 (if needed) and their bits passed as i64.
      if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
        if (!floatType.isF64())
          arg = rewriter.create<LLVM::FPExtOp>(
              loc, typeConverter->convertType(rewriter.getF64Type()), arg);
        arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
      }
      // Narrower integers are zero-extended to i64. NOTE(review): this is a
      // zero- (not sign-) extension; presumably the runtime only reads the
      // width implied by the conversion specifier. Confirm for negative
      // values. getIntOrFloatBitWidth() also assumes the operand is an
      // integer/float type at this point.
      if (arg.getType().getIntOrFloatBitWidth() != 64)
        arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);

      arguments.push_back(arg);
    }
    // Pad out to 7 arguments since the hostcall always needs 7
    for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
      arguments.push_back(zeroI64);
    }

    auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
    arguments.push_back(isLast);
    auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
    printfDesc = call.getResult(0);
  }
  rewriter.eraseOp(gpuPrintfOp);
  return success();
}
283 
284 LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
285     gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
286     ConversionPatternRewriter &rewriter) const {
287   Location loc = gpuPrintfOp->getLoc();
288 
289   mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
290   mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8, addressSpace);
291   mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
292 
293   // Note: this is the GPUModule op, not the ModuleOp that surrounds it
294   // This ensures that global constants and declarations are placed within
295   // the device code, not the host code
296   auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
297 
298   auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr},
299                                                 /*isVarArg=*/true);
300   LLVM::LLVMFuncOp printfDecl =
301       getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
302 
303   // Create a global constant for the format string
304   unsigned stringNumber = 0;
305   SmallString<16> stringConstName;
306   do {
307     stringConstName.clear();
308     (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
309   } while (moduleOp.lookupSymbol(stringConstName));
310 
311   llvm::SmallString<20> formatString(adaptor.format());
312   formatString.push_back('\0'); // Null terminate for C
313   auto globalType =
314       LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
315   LLVM::GlobalOp global;
316   {
317     ConversionPatternRewriter::InsertionGuard guard(rewriter);
318     rewriter.setInsertionPointToStart(moduleOp.getBody());
319     global = rewriter.create<LLVM::GlobalOp>(
320         loc, globalType,
321         /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
322         rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
323   }
324 
325   // Get a pointer to the format string's first element
326   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
327   Value zero = rewriter.create<LLVM::ConstantOp>(
328       loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
329   Value stringStart = rewriter.create<LLVM::GEPOp>(
330       loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
331 
332   // Construct arguments and function call
333   auto argsRange = adaptor.args();
334   SmallVector<Value, 4> printfArgs;
335   printfArgs.reserve(argsRange.size() + 1);
336   printfArgs.push_back(stringStart);
337   printfArgs.append(argsRange.begin(), argsRange.end());
338 
339   rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
340   rewriter.eraseOp(gpuPrintfOp);
341   return success();
342 }
343