//===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
84c4876c3SAlex Zinenko 
94c4876c3SAlex Zinenko #include "GPUOpsLowering.h"
10e1da6291SKrzysztof Drewniak #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
114c4876c3SAlex Zinenko #include "mlir/IR/Builders.h"
124c4876c3SAlex Zinenko #include "llvm/Support/FormatVariadic.h"
134c4876c3SAlex Zinenko 
144c4876c3SAlex Zinenko using namespace mlir;
154c4876c3SAlex Zinenko 
164c4876c3SAlex Zinenko LogicalResult
matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const17ef976337SRiver Riddle GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
184c4876c3SAlex Zinenko                                    ConversionPatternRewriter &rewriter) const {
194c4876c3SAlex Zinenko   Location loc = gpuFuncOp.getLoc();
204c4876c3SAlex Zinenko 
214c4876c3SAlex Zinenko   SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
224c4876c3SAlex Zinenko   workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
23e4853be2SMehdi Amini   for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
244c4876c3SAlex Zinenko     Value attribution = en.value();
254c4876c3SAlex Zinenko 
264c4876c3SAlex Zinenko     auto type = attribution.getType().dyn_cast<MemRefType>();
274c4876c3SAlex Zinenko     assert(type && type.hasStaticShape() && "unexpected type in attribution");
284c4876c3SAlex Zinenko 
294c4876c3SAlex Zinenko     uint64_t numElements = type.getNumElements();
304c4876c3SAlex Zinenko 
314c4876c3SAlex Zinenko     auto elementType =
324c4876c3SAlex Zinenko         typeConverter->convertType(type.getElementType()).template cast<Type>();
334c4876c3SAlex Zinenko     auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
344c4876c3SAlex Zinenko     std::string name = std::string(
354c4876c3SAlex Zinenko         llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
364c4876c3SAlex Zinenko     auto globalOp = rewriter.create<LLVM::GlobalOp>(
374c4876c3SAlex Zinenko         gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
384c4876c3SAlex Zinenko         LLVM::Linkage::Internal, name, /*value=*/Attribute(),
399a0ea599SDumitru Potop         /*alignment=*/0, gpu::GPUDialect::getWorkgroupAddressSpace());
404c4876c3SAlex Zinenko     workgroupBuffers.push_back(globalOp);
414c4876c3SAlex Zinenko   }
424c4876c3SAlex Zinenko 
434c4876c3SAlex Zinenko   // Rewrite the original GPU function to an LLVM function.
444a3460a7SRiver Riddle   auto funcType = typeConverter->convertType(gpuFuncOp.getFunctionType())
454c4876c3SAlex Zinenko                       .template cast<LLVM::LLVMPointerType>()
464c4876c3SAlex Zinenko                       .getElementType();
474c4876c3SAlex Zinenko 
484c4876c3SAlex Zinenko   // Remap proper input types.
494c4876c3SAlex Zinenko   TypeConverter::SignatureConversion signatureConversion(
504c4876c3SAlex Zinenko       gpuFuncOp.front().getNumArguments());
514c4876c3SAlex Zinenko   getTypeConverter()->convertFunctionSignature(
524a3460a7SRiver Riddle       gpuFuncOp.getFunctionType(), /*isVariadic=*/false, signatureConversion);
534c4876c3SAlex Zinenko 
544c4876c3SAlex Zinenko   // Create the new function operation. Only copy those attributes that are
554c4876c3SAlex Zinenko   // not specific to function modeling.
564c4876c3SAlex Zinenko   SmallVector<NamedAttribute, 4> attributes;
5756774bddSMarius Brehler   for (const auto &attr : gpuFuncOp->getAttrs()) {
580c7890c8SRiver Riddle     if (attr.getName() == SymbolTable::getSymbolAttrName() ||
597ceffae1SRiver Riddle         attr.getName() == FunctionOpInterface::getTypeAttrName() ||
600c7890c8SRiver Riddle         attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
614c4876c3SAlex Zinenko       continue;
624c4876c3SAlex Zinenko     attributes.push_back(attr);
634c4876c3SAlex Zinenko   }
644c4876c3SAlex Zinenko   // Add a dialect specific kernel attribute in addition to GPU kernel
654c4876c3SAlex Zinenko   // attribute. The former is necessary for further translation while the
664c4876c3SAlex Zinenko   // latter is expected by gpu.launch_func.
674c4876c3SAlex Zinenko   if (gpuFuncOp.isKernel())
684c4876c3SAlex Zinenko     attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
694c4876c3SAlex Zinenko   auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
704c4876c3SAlex Zinenko       gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
71*0252357bSAlexander Batashev       LLVM::Linkage::External, /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C,
72*0252357bSAlexander Batashev       attributes);
734c4876c3SAlex Zinenko 
744c4876c3SAlex Zinenko   {
754c4876c3SAlex Zinenko     // Insert operations that correspond to converted workgroup and private
764c4876c3SAlex Zinenko     // memory attributions to the body of the function. This must operate on
774c4876c3SAlex Zinenko     // the original function, before the body region is inlined in the new
784c4876c3SAlex Zinenko     // function to maintain the relation between block arguments and the
794c4876c3SAlex Zinenko     // parent operation that assigns their semantics.
804c4876c3SAlex Zinenko     OpBuilder::InsertionGuard guard(rewriter);
814c4876c3SAlex Zinenko 
824c4876c3SAlex Zinenko     // Rewrite workgroup memory attributions to addresses of global buffers.
834c4876c3SAlex Zinenko     rewriter.setInsertionPointToStart(&gpuFuncOp.front());
844c4876c3SAlex Zinenko     unsigned numProperArguments = gpuFuncOp.getNumArguments();
854c4876c3SAlex Zinenko     auto i32Type = IntegerType::get(rewriter.getContext(), 32);
864c4876c3SAlex Zinenko 
874c4876c3SAlex Zinenko     Value zero = nullptr;
884c4876c3SAlex Zinenko     if (!workgroupBuffers.empty())
894c4876c3SAlex Zinenko       zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
904c4876c3SAlex Zinenko                                                rewriter.getI32IntegerAttr(0));
91e4853be2SMehdi Amini     for (const auto &en : llvm::enumerate(workgroupBuffers)) {
924c4876c3SAlex Zinenko       LLVM::GlobalOp global = en.value();
934c4876c3SAlex Zinenko       Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
944c4876c3SAlex Zinenko       auto elementType =
954c4876c3SAlex Zinenko           global.getType().cast<LLVM::LLVMArrayType>().getElementType();
964c4876c3SAlex Zinenko       Value memory = rewriter.create<LLVM::GEPOp>(
97cfb72fd3SJacques Pienaar           loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()),
984c4876c3SAlex Zinenko           address, ArrayRef<Value>{zero, zero});
994c4876c3SAlex Zinenko 
1004c4876c3SAlex Zinenko       // Build a memref descriptor pointing to the buffer to plug with the
1014c4876c3SAlex Zinenko       // existing memref infrastructure. This may use more registers than
1024c4876c3SAlex Zinenko       // otherwise necessary given that memref sizes are fixed, but we can try
1034c4876c3SAlex Zinenko       // and canonicalize that away later.
1044c4876c3SAlex Zinenko       Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
1054c4876c3SAlex Zinenko       auto type = attribution.getType().cast<MemRefType>();
1064c4876c3SAlex Zinenko       auto descr = MemRefDescriptor::fromStaticShape(
1074c4876c3SAlex Zinenko           rewriter, loc, *getTypeConverter(), type, memory);
1084c4876c3SAlex Zinenko       signatureConversion.remapInput(numProperArguments + en.index(), descr);
1094c4876c3SAlex Zinenko     }
1104c4876c3SAlex Zinenko 
1114c4876c3SAlex Zinenko     // Rewrite private memory attributions to alloca'ed buffers.
1124c4876c3SAlex Zinenko     unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
1134c4876c3SAlex Zinenko     auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
114e4853be2SMehdi Amini     for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
1154c4876c3SAlex Zinenko       Value attribution = en.value();
1164c4876c3SAlex Zinenko       auto type = attribution.getType().cast<MemRefType>();
1174c4876c3SAlex Zinenko       assert(type && type.hasStaticShape() && "unexpected type in attribution");
1184c4876c3SAlex Zinenko 
1194c4876c3SAlex Zinenko       // Explicitly drop memory space when lowering private memory
1204c4876c3SAlex Zinenko       // attributions since NVVM models it as `alloca`s in the default
1214c4876c3SAlex Zinenko       // memory space and does not support `alloca`s with addrspace(5).
1224c4876c3SAlex Zinenko       auto ptrType = LLVM::LLVMPointerType::get(
1234c4876c3SAlex Zinenko           typeConverter->convertType(type.getElementType())
1244c4876c3SAlex Zinenko               .template cast<Type>(),
1254c4876c3SAlex Zinenko           allocaAddrSpace);
1264c4876c3SAlex Zinenko       Value numElements = rewriter.create<LLVM::ConstantOp>(
1274c4876c3SAlex Zinenko           gpuFuncOp.getLoc(), int64Ty,
1284c4876c3SAlex Zinenko           rewriter.getI64IntegerAttr(type.getNumElements()));
1294c4876c3SAlex Zinenko       Value allocated = rewriter.create<LLVM::AllocaOp>(
1304c4876c3SAlex Zinenko           gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
1314c4876c3SAlex Zinenko       auto descr = MemRefDescriptor::fromStaticShape(
1324c4876c3SAlex Zinenko           rewriter, loc, *getTypeConverter(), type, allocated);
1334c4876c3SAlex Zinenko       signatureConversion.remapInput(
1344c4876c3SAlex Zinenko           numProperArguments + numWorkgroupAttributions + en.index(), descr);
1354c4876c3SAlex Zinenko     }
1364c4876c3SAlex Zinenko   }
1374c4876c3SAlex Zinenko 
1384c4876c3SAlex Zinenko   // Move the region to the new function, update the entry block signature.
1394c4876c3SAlex Zinenko   rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
1404c4876c3SAlex Zinenko                               llvmFuncOp.end());
1414c4876c3SAlex Zinenko   if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
1424c4876c3SAlex Zinenko                                          &signatureConversion)))
1434c4876c3SAlex Zinenko     return failure();
1444c4876c3SAlex Zinenko 
1454c4876c3SAlex Zinenko   rewriter.eraseOp(gpuFuncOp);
1464c4876c3SAlex Zinenko   return success();
1474c4876c3SAlex Zinenko }
148e1da6291SKrzysztof Drewniak 
static const char formatStringPrefix[] = "printfFormat_";
150e1da6291SKrzysztof Drewniak 
151e1da6291SKrzysztof Drewniak template <typename T>
getOrDefineFunction(T & moduleOp,const Location loc,ConversionPatternRewriter & rewriter,StringRef name,LLVM::LLVMFunctionType type)152e1da6291SKrzysztof Drewniak static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
153e1da6291SKrzysztof Drewniak                                             ConversionPatternRewriter &rewriter,
154e1da6291SKrzysztof Drewniak                                             StringRef name,
155e1da6291SKrzysztof Drewniak                                             LLVM::LLVMFunctionType type) {
156e1da6291SKrzysztof Drewniak   LLVM::LLVMFuncOp ret;
157e1da6291SKrzysztof Drewniak   if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
158e1da6291SKrzysztof Drewniak     ConversionPatternRewriter::InsertionGuard guard(rewriter);
159e1da6291SKrzysztof Drewniak     rewriter.setInsertionPointToStart(moduleOp.getBody());
160e1da6291SKrzysztof Drewniak     ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
161e1da6291SKrzysztof Drewniak                                             LLVM::Linkage::External);
162e1da6291SKrzysztof Drewniak   }
163e1da6291SKrzysztof Drewniak   return ret;
164e1da6291SKrzysztof Drewniak }
165e1da6291SKrzysztof Drewniak 
matchAndRewrite(gpu::PrintfOp gpuPrintfOp,gpu::PrintfOpAdaptor adaptor,ConversionPatternRewriter & rewriter) const166e1da6291SKrzysztof Drewniak LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
167e1da6291SKrzysztof Drewniak     gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
168e1da6291SKrzysztof Drewniak     ConversionPatternRewriter &rewriter) const {
169e1da6291SKrzysztof Drewniak   Location loc = gpuPrintfOp->getLoc();
170e1da6291SKrzysztof Drewniak 
171e1da6291SKrzysztof Drewniak   mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
172e1da6291SKrzysztof Drewniak   mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
173e1da6291SKrzysztof Drewniak   mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
174e1da6291SKrzysztof Drewniak   mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
175e1da6291SKrzysztof Drewniak   mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
176e1da6291SKrzysztof Drewniak   // Note: this is the GPUModule op, not the ModuleOp that surrounds it
177e1da6291SKrzysztof Drewniak   // This ensures that global constants and declarations are placed within
178e1da6291SKrzysztof Drewniak   // the device code, not the host code
179e1da6291SKrzysztof Drewniak   auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
180e1da6291SKrzysztof Drewniak 
181e1da6291SKrzysztof Drewniak   auto ocklBegin =
182e1da6291SKrzysztof Drewniak       getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
183e1da6291SKrzysztof Drewniak                           LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
184e1da6291SKrzysztof Drewniak   LLVM::LLVMFuncOp ocklAppendArgs;
185e1da6291SKrzysztof Drewniak   if (!adaptor.args().empty()) {
186e1da6291SKrzysztof Drewniak     ocklAppendArgs = getOrDefineFunction(
187e1da6291SKrzysztof Drewniak         moduleOp, loc, rewriter, "__ockl_printf_append_args",
188e1da6291SKrzysztof Drewniak         LLVM::LLVMFunctionType::get(
189e1da6291SKrzysztof Drewniak             llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
190e1da6291SKrzysztof Drewniak                       llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
191e1da6291SKrzysztof Drewniak   }
192e1da6291SKrzysztof Drewniak   auto ocklAppendStringN = getOrDefineFunction(
193e1da6291SKrzysztof Drewniak       moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
194e1da6291SKrzysztof Drewniak       LLVM::LLVMFunctionType::get(
195e1da6291SKrzysztof Drewniak           llvmI64,
196e1da6291SKrzysztof Drewniak           {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
197e1da6291SKrzysztof Drewniak 
198e1da6291SKrzysztof Drewniak   /// Start the printf hostcall
199e1da6291SKrzysztof Drewniak   Value zeroI64 = rewriter.create<LLVM::ConstantOp>(
200e1da6291SKrzysztof Drewniak       loc, llvmI64, rewriter.getI64IntegerAttr(0));
201e1da6291SKrzysztof Drewniak   auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
202e1da6291SKrzysztof Drewniak   Value printfDesc = printfBeginCall.getResult(0);
203e1da6291SKrzysztof Drewniak 
204e1da6291SKrzysztof Drewniak   // Create a global constant for the format string
205e1da6291SKrzysztof Drewniak   unsigned stringNumber = 0;
206e1da6291SKrzysztof Drewniak   SmallString<16> stringConstName;
207e1da6291SKrzysztof Drewniak   do {
208e1da6291SKrzysztof Drewniak     stringConstName.clear();
209e1da6291SKrzysztof Drewniak     (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
210e1da6291SKrzysztof Drewniak   } while (moduleOp.lookupSymbol(stringConstName));
211e1da6291SKrzysztof Drewniak 
212564bcf9dSMehdi Amini   llvm::SmallString<20> formatString(adaptor.format());
213e1da6291SKrzysztof Drewniak   formatString.push_back('\0'); // Null terminate for C
214e1da6291SKrzysztof Drewniak   size_t formatStringSize = formatString.size_in_bytes();
215e1da6291SKrzysztof Drewniak 
216e1da6291SKrzysztof Drewniak   auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
217e1da6291SKrzysztof Drewniak   LLVM::GlobalOp global;
218e1da6291SKrzysztof Drewniak   {
219e1da6291SKrzysztof Drewniak     ConversionPatternRewriter::InsertionGuard guard(rewriter);
220e1da6291SKrzysztof Drewniak     rewriter.setInsertionPointToStart(moduleOp.getBody());
221e1da6291SKrzysztof Drewniak     global = rewriter.create<LLVM::GlobalOp>(
222e1da6291SKrzysztof Drewniak         loc, globalType,
223e1da6291SKrzysztof Drewniak         /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
224e1da6291SKrzysztof Drewniak         rewriter.getStringAttr(formatString));
225e1da6291SKrzysztof Drewniak   }
226e1da6291SKrzysztof Drewniak 
227e1da6291SKrzysztof Drewniak   // Get a pointer to the format string's first element and pass it to printf()
228e1da6291SKrzysztof Drewniak   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
229e1da6291SKrzysztof Drewniak   Value zero = rewriter.create<LLVM::ConstantOp>(
230e1da6291SKrzysztof Drewniak       loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
231e1da6291SKrzysztof Drewniak   Value stringStart = rewriter.create<LLVM::GEPOp>(
232e1da6291SKrzysztof Drewniak       loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
233e1da6291SKrzysztof Drewniak   Value stringLen = rewriter.create<LLVM::ConstantOp>(
234e1da6291SKrzysztof Drewniak       loc, llvmI64, rewriter.getI64IntegerAttr(formatStringSize));
235e1da6291SKrzysztof Drewniak 
236e1da6291SKrzysztof Drewniak   Value oneI32 = rewriter.create<LLVM::ConstantOp>(
237e1da6291SKrzysztof Drewniak       loc, llvmI32, rewriter.getI32IntegerAttr(1));
238e1da6291SKrzysztof Drewniak   Value zeroI32 = rewriter.create<LLVM::ConstantOp>(
239e1da6291SKrzysztof Drewniak       loc, llvmI32, rewriter.getI32IntegerAttr(0));
240e1da6291SKrzysztof Drewniak 
24179a0330aSMehdi Amini   auto appendFormatCall = rewriter.create<LLVM::CallOp>(
24279a0330aSMehdi Amini       loc, ocklAppendStringN,
24379a0330aSMehdi Amini       ValueRange{printfDesc, stringStart, stringLen,
24479a0330aSMehdi Amini                  adaptor.args().empty() ? oneI32 : zeroI32});
245e1da6291SKrzysztof Drewniak   printfDesc = appendFormatCall.getResult(0);
246e1da6291SKrzysztof Drewniak 
247e1da6291SKrzysztof Drewniak   // __ockl_printf_append_args takes 7 values per append call
248e1da6291SKrzysztof Drewniak   constexpr size_t argsPerAppend = 7;
249e1da6291SKrzysztof Drewniak   size_t nArgs = adaptor.args().size();
250e1da6291SKrzysztof Drewniak   for (size_t group = 0; group < nArgs; group += argsPerAppend) {
251e1da6291SKrzysztof Drewniak     size_t bound = std::min(group + argsPerAppend, nArgs);
252e1da6291SKrzysztof Drewniak     size_t numArgsThisCall = bound - group;
253e1da6291SKrzysztof Drewniak 
254e1da6291SKrzysztof Drewniak     SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
255e1da6291SKrzysztof Drewniak     arguments.push_back(printfDesc);
256e1da6291SKrzysztof Drewniak     arguments.push_back(rewriter.create<LLVM::ConstantOp>(
257e1da6291SKrzysztof Drewniak         loc, llvmI32, rewriter.getI32IntegerAttr(numArgsThisCall)));
258e1da6291SKrzysztof Drewniak     for (size_t i = group; i < bound; ++i) {
259e1da6291SKrzysztof Drewniak       Value arg = adaptor.args()[i];
260e1da6291SKrzysztof Drewniak       if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
261e1da6291SKrzysztof Drewniak         if (!floatType.isF64())
262e1da6291SKrzysztof Drewniak           arg = rewriter.create<LLVM::FPExtOp>(
263e1da6291SKrzysztof Drewniak               loc, typeConverter->convertType(rewriter.getF64Type()), arg);
264e1da6291SKrzysztof Drewniak         arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
265e1da6291SKrzysztof Drewniak       }
266e1da6291SKrzysztof Drewniak       if (arg.getType().getIntOrFloatBitWidth() != 64)
267e1da6291SKrzysztof Drewniak         arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
268e1da6291SKrzysztof Drewniak 
269e1da6291SKrzysztof Drewniak       arguments.push_back(arg);
270e1da6291SKrzysztof Drewniak     }
271e1da6291SKrzysztof Drewniak     // Pad out to 7 arguments since the hostcall always needs 7
272e1da6291SKrzysztof Drewniak     for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
273e1da6291SKrzysztof Drewniak       arguments.push_back(zeroI64);
274e1da6291SKrzysztof Drewniak     }
275e1da6291SKrzysztof Drewniak 
276e1da6291SKrzysztof Drewniak     auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
277e1da6291SKrzysztof Drewniak     arguments.push_back(isLast);
278e1da6291SKrzysztof Drewniak     auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
279e1da6291SKrzysztof Drewniak     printfDesc = call.getResult(0);
280e1da6291SKrzysztof Drewniak   }
281e1da6291SKrzysztof Drewniak   rewriter.eraseOp(gpuPrintfOp);
282e1da6291SKrzysztof Drewniak   return success();
283e1da6291SKrzysztof Drewniak }
284e1da6291SKrzysztof Drewniak 
matchAndRewrite(gpu::PrintfOp gpuPrintfOp,gpu::PrintfOpAdaptor adaptor,ConversionPatternRewriter & rewriter) const285e1da6291SKrzysztof Drewniak LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
286e1da6291SKrzysztof Drewniak     gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
287e1da6291SKrzysztof Drewniak     ConversionPatternRewriter &rewriter) const {
288e1da6291SKrzysztof Drewniak   Location loc = gpuPrintfOp->getLoc();
289e1da6291SKrzysztof Drewniak 
290e1da6291SKrzysztof Drewniak   mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
291e1da6291SKrzysztof Drewniak   mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8, addressSpace);
292e1da6291SKrzysztof Drewniak   mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
293e1da6291SKrzysztof Drewniak 
294e1da6291SKrzysztof Drewniak   // Note: this is the GPUModule op, not the ModuleOp that surrounds it
295e1da6291SKrzysztof Drewniak   // This ensures that global constants and declarations are placed within
296e1da6291SKrzysztof Drewniak   // the device code, not the host code
297e1da6291SKrzysztof Drewniak   auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
298e1da6291SKrzysztof Drewniak 
299e1da6291SKrzysztof Drewniak   auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr},
300e1da6291SKrzysztof Drewniak                                                 /*isVarArg=*/true);
301e1da6291SKrzysztof Drewniak   LLVM::LLVMFuncOp printfDecl =
302e1da6291SKrzysztof Drewniak       getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
303e1da6291SKrzysztof Drewniak 
304e1da6291SKrzysztof Drewniak   // Create a global constant for the format string
305e1da6291SKrzysztof Drewniak   unsigned stringNumber = 0;
306e1da6291SKrzysztof Drewniak   SmallString<16> stringConstName;
307e1da6291SKrzysztof Drewniak   do {
308e1da6291SKrzysztof Drewniak     stringConstName.clear();
309e1da6291SKrzysztof Drewniak     (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
310e1da6291SKrzysztof Drewniak   } while (moduleOp.lookupSymbol(stringConstName));
311e1da6291SKrzysztof Drewniak 
312564bcf9dSMehdi Amini   llvm::SmallString<20> formatString(adaptor.format());
313e1da6291SKrzysztof Drewniak   formatString.push_back('\0'); // Null terminate for C
314e1da6291SKrzysztof Drewniak   auto globalType =
315e1da6291SKrzysztof Drewniak       LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
316e1da6291SKrzysztof Drewniak   LLVM::GlobalOp global;
317e1da6291SKrzysztof Drewniak   {
318e1da6291SKrzysztof Drewniak     ConversionPatternRewriter::InsertionGuard guard(rewriter);
319e1da6291SKrzysztof Drewniak     rewriter.setInsertionPointToStart(moduleOp.getBody());
320e1da6291SKrzysztof Drewniak     global = rewriter.create<LLVM::GlobalOp>(
321e1da6291SKrzysztof Drewniak         loc, globalType,
322e1da6291SKrzysztof Drewniak         /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
323e1da6291SKrzysztof Drewniak         rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
324e1da6291SKrzysztof Drewniak   }
325e1da6291SKrzysztof Drewniak 
326e1da6291SKrzysztof Drewniak   // Get a pointer to the format string's first element
327e1da6291SKrzysztof Drewniak   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
328e1da6291SKrzysztof Drewniak   Value zero = rewriter.create<LLVM::ConstantOp>(
329e1da6291SKrzysztof Drewniak       loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
330e1da6291SKrzysztof Drewniak   Value stringStart = rewriter.create<LLVM::GEPOp>(
331e1da6291SKrzysztof Drewniak       loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
332e1da6291SKrzysztof Drewniak 
333e1da6291SKrzysztof Drewniak   // Construct arguments and function call
334e1da6291SKrzysztof Drewniak   auto argsRange = adaptor.args();
335e1da6291SKrzysztof Drewniak   SmallVector<Value, 4> printfArgs;
336e1da6291SKrzysztof Drewniak   printfArgs.reserve(argsRange.size() + 1);
337e1da6291SKrzysztof Drewniak   printfArgs.push_back(stringStart);
338e1da6291SKrzysztof Drewniak   printfArgs.append(argsRange.begin(), argsRange.end());
339e1da6291SKrzysztof Drewniak 
340e1da6291SKrzysztof Drewniak   rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
341e1da6291SKrzysztof Drewniak   rewriter.eraseOp(gpuPrintfOp);
342e1da6291SKrzysztof Drewniak   return success();
343e1da6291SKrzysztof Drewniak }
344