//===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
84c4876c3SAlex Zinenko
94c4876c3SAlex Zinenko #include "GPUOpsLowering.h"
10e1da6291SKrzysztof Drewniak #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
114c4876c3SAlex Zinenko #include "mlir/IR/Builders.h"
124c4876c3SAlex Zinenko #include "llvm/Support/FormatVariadic.h"
134c4876c3SAlex Zinenko
144c4876c3SAlex Zinenko using namespace mlir;
154c4876c3SAlex Zinenko
164c4876c3SAlex Zinenko LogicalResult
matchAndRewrite(gpu::GPUFuncOp gpuFuncOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const17ef976337SRiver Riddle GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor,
184c4876c3SAlex Zinenko ConversionPatternRewriter &rewriter) const {
194c4876c3SAlex Zinenko Location loc = gpuFuncOp.getLoc();
204c4876c3SAlex Zinenko
214c4876c3SAlex Zinenko SmallVector<LLVM::GlobalOp, 3> workgroupBuffers;
224c4876c3SAlex Zinenko workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions());
23e4853be2SMehdi Amini for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) {
244c4876c3SAlex Zinenko Value attribution = en.value();
254c4876c3SAlex Zinenko
264c4876c3SAlex Zinenko auto type = attribution.getType().dyn_cast<MemRefType>();
274c4876c3SAlex Zinenko assert(type && type.hasStaticShape() && "unexpected type in attribution");
284c4876c3SAlex Zinenko
294c4876c3SAlex Zinenko uint64_t numElements = type.getNumElements();
304c4876c3SAlex Zinenko
314c4876c3SAlex Zinenko auto elementType =
324c4876c3SAlex Zinenko typeConverter->convertType(type.getElementType()).template cast<Type>();
334c4876c3SAlex Zinenko auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements);
344c4876c3SAlex Zinenko std::string name = std::string(
354c4876c3SAlex Zinenko llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index()));
364c4876c3SAlex Zinenko auto globalOp = rewriter.create<LLVM::GlobalOp>(
374c4876c3SAlex Zinenko gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false,
384c4876c3SAlex Zinenko LLVM::Linkage::Internal, name, /*value=*/Attribute(),
399a0ea599SDumitru Potop /*alignment=*/0, gpu::GPUDialect::getWorkgroupAddressSpace());
404c4876c3SAlex Zinenko workgroupBuffers.push_back(globalOp);
414c4876c3SAlex Zinenko }
424c4876c3SAlex Zinenko
434c4876c3SAlex Zinenko // Rewrite the original GPU function to an LLVM function.
444a3460a7SRiver Riddle auto funcType = typeConverter->convertType(gpuFuncOp.getFunctionType())
454c4876c3SAlex Zinenko .template cast<LLVM::LLVMPointerType>()
464c4876c3SAlex Zinenko .getElementType();
474c4876c3SAlex Zinenko
484c4876c3SAlex Zinenko // Remap proper input types.
494c4876c3SAlex Zinenko TypeConverter::SignatureConversion signatureConversion(
504c4876c3SAlex Zinenko gpuFuncOp.front().getNumArguments());
514c4876c3SAlex Zinenko getTypeConverter()->convertFunctionSignature(
524a3460a7SRiver Riddle gpuFuncOp.getFunctionType(), /*isVariadic=*/false, signatureConversion);
534c4876c3SAlex Zinenko
544c4876c3SAlex Zinenko // Create the new function operation. Only copy those attributes that are
554c4876c3SAlex Zinenko // not specific to function modeling.
564c4876c3SAlex Zinenko SmallVector<NamedAttribute, 4> attributes;
5756774bddSMarius Brehler for (const auto &attr : gpuFuncOp->getAttrs()) {
580c7890c8SRiver Riddle if (attr.getName() == SymbolTable::getSymbolAttrName() ||
597ceffae1SRiver Riddle attr.getName() == FunctionOpInterface::getTypeAttrName() ||
600c7890c8SRiver Riddle attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName())
614c4876c3SAlex Zinenko continue;
624c4876c3SAlex Zinenko attributes.push_back(attr);
634c4876c3SAlex Zinenko }
644c4876c3SAlex Zinenko // Add a dialect specific kernel attribute in addition to GPU kernel
654c4876c3SAlex Zinenko // attribute. The former is necessary for further translation while the
664c4876c3SAlex Zinenko // latter is expected by gpu.launch_func.
674c4876c3SAlex Zinenko if (gpuFuncOp.isKernel())
684c4876c3SAlex Zinenko attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr());
694c4876c3SAlex Zinenko auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>(
704c4876c3SAlex Zinenko gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType,
71*0252357bSAlexander Batashev LLVM::Linkage::External, /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C,
72*0252357bSAlexander Batashev attributes);
734c4876c3SAlex Zinenko
744c4876c3SAlex Zinenko {
754c4876c3SAlex Zinenko // Insert operations that correspond to converted workgroup and private
764c4876c3SAlex Zinenko // memory attributions to the body of the function. This must operate on
774c4876c3SAlex Zinenko // the original function, before the body region is inlined in the new
784c4876c3SAlex Zinenko // function to maintain the relation between block arguments and the
794c4876c3SAlex Zinenko // parent operation that assigns their semantics.
804c4876c3SAlex Zinenko OpBuilder::InsertionGuard guard(rewriter);
814c4876c3SAlex Zinenko
824c4876c3SAlex Zinenko // Rewrite workgroup memory attributions to addresses of global buffers.
834c4876c3SAlex Zinenko rewriter.setInsertionPointToStart(&gpuFuncOp.front());
844c4876c3SAlex Zinenko unsigned numProperArguments = gpuFuncOp.getNumArguments();
854c4876c3SAlex Zinenko auto i32Type = IntegerType::get(rewriter.getContext(), 32);
864c4876c3SAlex Zinenko
874c4876c3SAlex Zinenko Value zero = nullptr;
884c4876c3SAlex Zinenko if (!workgroupBuffers.empty())
894c4876c3SAlex Zinenko zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type,
904c4876c3SAlex Zinenko rewriter.getI32IntegerAttr(0));
91e4853be2SMehdi Amini for (const auto &en : llvm::enumerate(workgroupBuffers)) {
924c4876c3SAlex Zinenko LLVM::GlobalOp global = en.value();
934c4876c3SAlex Zinenko Value address = rewriter.create<LLVM::AddressOfOp>(loc, global);
944c4876c3SAlex Zinenko auto elementType =
954c4876c3SAlex Zinenko global.getType().cast<LLVM::LLVMArrayType>().getElementType();
964c4876c3SAlex Zinenko Value memory = rewriter.create<LLVM::GEPOp>(
97cfb72fd3SJacques Pienaar loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()),
984c4876c3SAlex Zinenko address, ArrayRef<Value>{zero, zero});
994c4876c3SAlex Zinenko
1004c4876c3SAlex Zinenko // Build a memref descriptor pointing to the buffer to plug with the
1014c4876c3SAlex Zinenko // existing memref infrastructure. This may use more registers than
1024c4876c3SAlex Zinenko // otherwise necessary given that memref sizes are fixed, but we can try
1034c4876c3SAlex Zinenko // and canonicalize that away later.
1044c4876c3SAlex Zinenko Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()];
1054c4876c3SAlex Zinenko auto type = attribution.getType().cast<MemRefType>();
1064c4876c3SAlex Zinenko auto descr = MemRefDescriptor::fromStaticShape(
1074c4876c3SAlex Zinenko rewriter, loc, *getTypeConverter(), type, memory);
1084c4876c3SAlex Zinenko signatureConversion.remapInput(numProperArguments + en.index(), descr);
1094c4876c3SAlex Zinenko }
1104c4876c3SAlex Zinenko
1114c4876c3SAlex Zinenko // Rewrite private memory attributions to alloca'ed buffers.
1124c4876c3SAlex Zinenko unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions();
1134c4876c3SAlex Zinenko auto int64Ty = IntegerType::get(rewriter.getContext(), 64);
114e4853be2SMehdi Amini for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) {
1154c4876c3SAlex Zinenko Value attribution = en.value();
1164c4876c3SAlex Zinenko auto type = attribution.getType().cast<MemRefType>();
1174c4876c3SAlex Zinenko assert(type && type.hasStaticShape() && "unexpected type in attribution");
1184c4876c3SAlex Zinenko
1194c4876c3SAlex Zinenko // Explicitly drop memory space when lowering private memory
1204c4876c3SAlex Zinenko // attributions since NVVM models it as `alloca`s in the default
1214c4876c3SAlex Zinenko // memory space and does not support `alloca`s with addrspace(5).
1224c4876c3SAlex Zinenko auto ptrType = LLVM::LLVMPointerType::get(
1234c4876c3SAlex Zinenko typeConverter->convertType(type.getElementType())
1244c4876c3SAlex Zinenko .template cast<Type>(),
1254c4876c3SAlex Zinenko allocaAddrSpace);
1264c4876c3SAlex Zinenko Value numElements = rewriter.create<LLVM::ConstantOp>(
1274c4876c3SAlex Zinenko gpuFuncOp.getLoc(), int64Ty,
1284c4876c3SAlex Zinenko rewriter.getI64IntegerAttr(type.getNumElements()));
1294c4876c3SAlex Zinenko Value allocated = rewriter.create<LLVM::AllocaOp>(
1304c4876c3SAlex Zinenko gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0);
1314c4876c3SAlex Zinenko auto descr = MemRefDescriptor::fromStaticShape(
1324c4876c3SAlex Zinenko rewriter, loc, *getTypeConverter(), type, allocated);
1334c4876c3SAlex Zinenko signatureConversion.remapInput(
1344c4876c3SAlex Zinenko numProperArguments + numWorkgroupAttributions + en.index(), descr);
1354c4876c3SAlex Zinenko }
1364c4876c3SAlex Zinenko }
1374c4876c3SAlex Zinenko
1384c4876c3SAlex Zinenko // Move the region to the new function, update the entry block signature.
1394c4876c3SAlex Zinenko rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(),
1404c4876c3SAlex Zinenko llvmFuncOp.end());
1414c4876c3SAlex Zinenko if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter,
1424c4876c3SAlex Zinenko &signatureConversion)))
1434c4876c3SAlex Zinenko return failure();
1444c4876c3SAlex Zinenko
1454c4876c3SAlex Zinenko rewriter.eraseOp(gpuFuncOp);
1464c4876c3SAlex Zinenko return success();
1474c4876c3SAlex Zinenko }
148e1da6291SKrzysztof Drewniak
/// Common prefix of the symbol names given to the global constants that hold
/// printf format strings; a numeric suffix is appended to make each name
/// unique within its module.
static constexpr char formatStringPrefix[] = "printfFormat_";
150e1da6291SKrzysztof Drewniak
151e1da6291SKrzysztof Drewniak template <typename T>
getOrDefineFunction(T & moduleOp,const Location loc,ConversionPatternRewriter & rewriter,StringRef name,LLVM::LLVMFunctionType type)152e1da6291SKrzysztof Drewniak static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
153e1da6291SKrzysztof Drewniak ConversionPatternRewriter &rewriter,
154e1da6291SKrzysztof Drewniak StringRef name,
155e1da6291SKrzysztof Drewniak LLVM::LLVMFunctionType type) {
156e1da6291SKrzysztof Drewniak LLVM::LLVMFuncOp ret;
157e1da6291SKrzysztof Drewniak if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) {
158e1da6291SKrzysztof Drewniak ConversionPatternRewriter::InsertionGuard guard(rewriter);
159e1da6291SKrzysztof Drewniak rewriter.setInsertionPointToStart(moduleOp.getBody());
160e1da6291SKrzysztof Drewniak ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type,
161e1da6291SKrzysztof Drewniak LLVM::Linkage::External);
162e1da6291SKrzysztof Drewniak }
163e1da6291SKrzysztof Drewniak return ret;
164e1da6291SKrzysztof Drewniak }
165e1da6291SKrzysztof Drewniak
matchAndRewrite(gpu::PrintfOp gpuPrintfOp,gpu::PrintfOpAdaptor adaptor,ConversionPatternRewriter & rewriter) const166e1da6291SKrzysztof Drewniak LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
167e1da6291SKrzysztof Drewniak gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
168e1da6291SKrzysztof Drewniak ConversionPatternRewriter &rewriter) const {
169e1da6291SKrzysztof Drewniak Location loc = gpuPrintfOp->getLoc();
170e1da6291SKrzysztof Drewniak
171e1da6291SKrzysztof Drewniak mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type());
172e1da6291SKrzysztof Drewniak mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8);
173e1da6291SKrzysztof Drewniak mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
174e1da6291SKrzysztof Drewniak mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type());
175e1da6291SKrzysztof Drewniak mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type());
176e1da6291SKrzysztof Drewniak // Note: this is the GPUModule op, not the ModuleOp that surrounds it
177e1da6291SKrzysztof Drewniak // This ensures that global constants and declarations are placed within
178e1da6291SKrzysztof Drewniak // the device code, not the host code
179e1da6291SKrzysztof Drewniak auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
180e1da6291SKrzysztof Drewniak
181e1da6291SKrzysztof Drewniak auto ocklBegin =
182e1da6291SKrzysztof Drewniak getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin",
183e1da6291SKrzysztof Drewniak LLVM::LLVMFunctionType::get(llvmI64, {llvmI64}));
184e1da6291SKrzysztof Drewniak LLVM::LLVMFuncOp ocklAppendArgs;
185e1da6291SKrzysztof Drewniak if (!adaptor.args().empty()) {
186e1da6291SKrzysztof Drewniak ocklAppendArgs = getOrDefineFunction(
187e1da6291SKrzysztof Drewniak moduleOp, loc, rewriter, "__ockl_printf_append_args",
188e1da6291SKrzysztof Drewniak LLVM::LLVMFunctionType::get(
189e1da6291SKrzysztof Drewniak llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64,
190e1da6291SKrzysztof Drewniak llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32}));
191e1da6291SKrzysztof Drewniak }
192e1da6291SKrzysztof Drewniak auto ocklAppendStringN = getOrDefineFunction(
193e1da6291SKrzysztof Drewniak moduleOp, loc, rewriter, "__ockl_printf_append_string_n",
194e1da6291SKrzysztof Drewniak LLVM::LLVMFunctionType::get(
195e1da6291SKrzysztof Drewniak llvmI64,
196e1da6291SKrzysztof Drewniak {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32}));
197e1da6291SKrzysztof Drewniak
198e1da6291SKrzysztof Drewniak /// Start the printf hostcall
199e1da6291SKrzysztof Drewniak Value zeroI64 = rewriter.create<LLVM::ConstantOp>(
200e1da6291SKrzysztof Drewniak loc, llvmI64, rewriter.getI64IntegerAttr(0));
201e1da6291SKrzysztof Drewniak auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
202e1da6291SKrzysztof Drewniak Value printfDesc = printfBeginCall.getResult(0);
203e1da6291SKrzysztof Drewniak
204e1da6291SKrzysztof Drewniak // Create a global constant for the format string
205e1da6291SKrzysztof Drewniak unsigned stringNumber = 0;
206e1da6291SKrzysztof Drewniak SmallString<16> stringConstName;
207e1da6291SKrzysztof Drewniak do {
208e1da6291SKrzysztof Drewniak stringConstName.clear();
209e1da6291SKrzysztof Drewniak (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
210e1da6291SKrzysztof Drewniak } while (moduleOp.lookupSymbol(stringConstName));
211e1da6291SKrzysztof Drewniak
212564bcf9dSMehdi Amini llvm::SmallString<20> formatString(adaptor.format());
213e1da6291SKrzysztof Drewniak formatString.push_back('\0'); // Null terminate for C
214e1da6291SKrzysztof Drewniak size_t formatStringSize = formatString.size_in_bytes();
215e1da6291SKrzysztof Drewniak
216e1da6291SKrzysztof Drewniak auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
217e1da6291SKrzysztof Drewniak LLVM::GlobalOp global;
218e1da6291SKrzysztof Drewniak {
219e1da6291SKrzysztof Drewniak ConversionPatternRewriter::InsertionGuard guard(rewriter);
220e1da6291SKrzysztof Drewniak rewriter.setInsertionPointToStart(moduleOp.getBody());
221e1da6291SKrzysztof Drewniak global = rewriter.create<LLVM::GlobalOp>(
222e1da6291SKrzysztof Drewniak loc, globalType,
223e1da6291SKrzysztof Drewniak /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
224e1da6291SKrzysztof Drewniak rewriter.getStringAttr(formatString));
225e1da6291SKrzysztof Drewniak }
226e1da6291SKrzysztof Drewniak
227e1da6291SKrzysztof Drewniak // Get a pointer to the format string's first element and pass it to printf()
228e1da6291SKrzysztof Drewniak Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
229e1da6291SKrzysztof Drewniak Value zero = rewriter.create<LLVM::ConstantOp>(
230e1da6291SKrzysztof Drewniak loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
231e1da6291SKrzysztof Drewniak Value stringStart = rewriter.create<LLVM::GEPOp>(
232e1da6291SKrzysztof Drewniak loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
233e1da6291SKrzysztof Drewniak Value stringLen = rewriter.create<LLVM::ConstantOp>(
234e1da6291SKrzysztof Drewniak loc, llvmI64, rewriter.getI64IntegerAttr(formatStringSize));
235e1da6291SKrzysztof Drewniak
236e1da6291SKrzysztof Drewniak Value oneI32 = rewriter.create<LLVM::ConstantOp>(
237e1da6291SKrzysztof Drewniak loc, llvmI32, rewriter.getI32IntegerAttr(1));
238e1da6291SKrzysztof Drewniak Value zeroI32 = rewriter.create<LLVM::ConstantOp>(
239e1da6291SKrzysztof Drewniak loc, llvmI32, rewriter.getI32IntegerAttr(0));
240e1da6291SKrzysztof Drewniak
24179a0330aSMehdi Amini auto appendFormatCall = rewriter.create<LLVM::CallOp>(
24279a0330aSMehdi Amini loc, ocklAppendStringN,
24379a0330aSMehdi Amini ValueRange{printfDesc, stringStart, stringLen,
24479a0330aSMehdi Amini adaptor.args().empty() ? oneI32 : zeroI32});
245e1da6291SKrzysztof Drewniak printfDesc = appendFormatCall.getResult(0);
246e1da6291SKrzysztof Drewniak
247e1da6291SKrzysztof Drewniak // __ockl_printf_append_args takes 7 values per append call
248e1da6291SKrzysztof Drewniak constexpr size_t argsPerAppend = 7;
249e1da6291SKrzysztof Drewniak size_t nArgs = adaptor.args().size();
250e1da6291SKrzysztof Drewniak for (size_t group = 0; group < nArgs; group += argsPerAppend) {
251e1da6291SKrzysztof Drewniak size_t bound = std::min(group + argsPerAppend, nArgs);
252e1da6291SKrzysztof Drewniak size_t numArgsThisCall = bound - group;
253e1da6291SKrzysztof Drewniak
254e1da6291SKrzysztof Drewniak SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments;
255e1da6291SKrzysztof Drewniak arguments.push_back(printfDesc);
256e1da6291SKrzysztof Drewniak arguments.push_back(rewriter.create<LLVM::ConstantOp>(
257e1da6291SKrzysztof Drewniak loc, llvmI32, rewriter.getI32IntegerAttr(numArgsThisCall)));
258e1da6291SKrzysztof Drewniak for (size_t i = group; i < bound; ++i) {
259e1da6291SKrzysztof Drewniak Value arg = adaptor.args()[i];
260e1da6291SKrzysztof Drewniak if (auto floatType = arg.getType().dyn_cast<FloatType>()) {
261e1da6291SKrzysztof Drewniak if (!floatType.isF64())
262e1da6291SKrzysztof Drewniak arg = rewriter.create<LLVM::FPExtOp>(
263e1da6291SKrzysztof Drewniak loc, typeConverter->convertType(rewriter.getF64Type()), arg);
264e1da6291SKrzysztof Drewniak arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg);
265e1da6291SKrzysztof Drewniak }
266e1da6291SKrzysztof Drewniak if (arg.getType().getIntOrFloatBitWidth() != 64)
267e1da6291SKrzysztof Drewniak arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg);
268e1da6291SKrzysztof Drewniak
269e1da6291SKrzysztof Drewniak arguments.push_back(arg);
270e1da6291SKrzysztof Drewniak }
271e1da6291SKrzysztof Drewniak // Pad out to 7 arguments since the hostcall always needs 7
272e1da6291SKrzysztof Drewniak for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) {
273e1da6291SKrzysztof Drewniak arguments.push_back(zeroI64);
274e1da6291SKrzysztof Drewniak }
275e1da6291SKrzysztof Drewniak
276e1da6291SKrzysztof Drewniak auto isLast = (bound == nArgs) ? oneI32 : zeroI32;
277e1da6291SKrzysztof Drewniak arguments.push_back(isLast);
278e1da6291SKrzysztof Drewniak auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments);
279e1da6291SKrzysztof Drewniak printfDesc = call.getResult(0);
280e1da6291SKrzysztof Drewniak }
281e1da6291SKrzysztof Drewniak rewriter.eraseOp(gpuPrintfOp);
282e1da6291SKrzysztof Drewniak return success();
283e1da6291SKrzysztof Drewniak }
284e1da6291SKrzysztof Drewniak
matchAndRewrite(gpu::PrintfOp gpuPrintfOp,gpu::PrintfOpAdaptor adaptor,ConversionPatternRewriter & rewriter) const285e1da6291SKrzysztof Drewniak LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
286e1da6291SKrzysztof Drewniak gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor,
287e1da6291SKrzysztof Drewniak ConversionPatternRewriter &rewriter) const {
288e1da6291SKrzysztof Drewniak Location loc = gpuPrintfOp->getLoc();
289e1da6291SKrzysztof Drewniak
290e1da6291SKrzysztof Drewniak mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8));
291e1da6291SKrzysztof Drewniak mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8, addressSpace);
292e1da6291SKrzysztof Drewniak mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType());
293e1da6291SKrzysztof Drewniak
294e1da6291SKrzysztof Drewniak // Note: this is the GPUModule op, not the ModuleOp that surrounds it
295e1da6291SKrzysztof Drewniak // This ensures that global constants and declarations are placed within
296e1da6291SKrzysztof Drewniak // the device code, not the host code
297e1da6291SKrzysztof Drewniak auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>();
298e1da6291SKrzysztof Drewniak
299e1da6291SKrzysztof Drewniak auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr},
300e1da6291SKrzysztof Drewniak /*isVarArg=*/true);
301e1da6291SKrzysztof Drewniak LLVM::LLVMFuncOp printfDecl =
302e1da6291SKrzysztof Drewniak getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
303e1da6291SKrzysztof Drewniak
304e1da6291SKrzysztof Drewniak // Create a global constant for the format string
305e1da6291SKrzysztof Drewniak unsigned stringNumber = 0;
306e1da6291SKrzysztof Drewniak SmallString<16> stringConstName;
307e1da6291SKrzysztof Drewniak do {
308e1da6291SKrzysztof Drewniak stringConstName.clear();
309e1da6291SKrzysztof Drewniak (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
310e1da6291SKrzysztof Drewniak } while (moduleOp.lookupSymbol(stringConstName));
311e1da6291SKrzysztof Drewniak
312564bcf9dSMehdi Amini llvm::SmallString<20> formatString(adaptor.format());
313e1da6291SKrzysztof Drewniak formatString.push_back('\0'); // Null terminate for C
314e1da6291SKrzysztof Drewniak auto globalType =
315e1da6291SKrzysztof Drewniak LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
316e1da6291SKrzysztof Drewniak LLVM::GlobalOp global;
317e1da6291SKrzysztof Drewniak {
318e1da6291SKrzysztof Drewniak ConversionPatternRewriter::InsertionGuard guard(rewriter);
319e1da6291SKrzysztof Drewniak rewriter.setInsertionPointToStart(moduleOp.getBody());
320e1da6291SKrzysztof Drewniak global = rewriter.create<LLVM::GlobalOp>(
321e1da6291SKrzysztof Drewniak loc, globalType,
322e1da6291SKrzysztof Drewniak /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
323e1da6291SKrzysztof Drewniak rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
324e1da6291SKrzysztof Drewniak }
325e1da6291SKrzysztof Drewniak
326e1da6291SKrzysztof Drewniak // Get a pointer to the format string's first element
327e1da6291SKrzysztof Drewniak Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
328e1da6291SKrzysztof Drewniak Value zero = rewriter.create<LLVM::ConstantOp>(
329e1da6291SKrzysztof Drewniak loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0));
330e1da6291SKrzysztof Drewniak Value stringStart = rewriter.create<LLVM::GEPOp>(
331e1da6291SKrzysztof Drewniak loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero}));
332e1da6291SKrzysztof Drewniak
333e1da6291SKrzysztof Drewniak // Construct arguments and function call
334e1da6291SKrzysztof Drewniak auto argsRange = adaptor.args();
335e1da6291SKrzysztof Drewniak SmallVector<Value, 4> printfArgs;
336e1da6291SKrzysztof Drewniak printfArgs.reserve(argsRange.size() + 1);
337e1da6291SKrzysztof Drewniak printfArgs.push_back(stringStart);
338e1da6291SKrzysztof Drewniak printfArgs.append(argsRange.begin(), argsRange.end());
339e1da6291SKrzysztof Drewniak
340e1da6291SKrzysztof Drewniak rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs);
341e1da6291SKrzysztof Drewniak rewriter.eraseOp(gpuPrintfOp);
342e1da6291SKrzysztof Drewniak return success();
343e1da6291SKrzysztof Drewniak }
344