//===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "GPUOpsLowering.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Builders.h" #include "llvm/Support/FormatVariadic.h" using namespace mlir; LogicalResult GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = gpuFuncOp.getLoc(); SmallVector workgroupBuffers; workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions()); for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { Value attribution = en.value(); auto type = attribution.getType().dyn_cast(); assert(type && type.hasStaticShape() && "unexpected type in attribution"); uint64_t numElements = type.getNumElements(); auto elementType = typeConverter->convertType(type.getElementType()).template cast(); auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); std::string name = std::string( llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index())); auto globalOp = rewriter.create( gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, LLVM::Linkage::Internal, name, /*value=*/Attribute(), /*alignment=*/0, gpu::GPUDialect::getWorkgroupAddressSpace()); workgroupBuffers.push_back(globalOp); } // Rewrite the original GPU function to an LLVM function. auto funcType = typeConverter->convertType(gpuFuncOp.getFunctionType()) .template cast() .getElementType(); // Remap proper input types. TypeConverter::SignatureConversion signatureConversion( gpuFuncOp.front().getNumArguments()); getTypeConverter()->convertFunctionSignature( gpuFuncOp.getFunctionType(), /*isVariadic=*/false, signatureConversion); // Create the new function operation. Only copy those attributes that are // not specific to function modeling. SmallVector attributes; for (const auto &attr : gpuFuncOp->getAttrs()) { if (attr.getName() == SymbolTable::getSymbolAttrName() || attr.getName() == FunctionOpInterface::getTypeAttrName() || attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()) continue; attributes.push_back(attr); } // Add a dialect specific kernel attribute in addition to GPU kernel // attribute. The former is necessary for further translation while the // latter is expected by gpu.launch_func. if (gpuFuncOp.isKernel()) attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr()); auto llvmFuncOp = rewriter.create( gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, LLVM::Linkage::External, /*dsoLocal*/ false, /*cconv*/ LLVM::CConv::C, attributes); { // Insert operations that correspond to converted workgroup and private // memory attributions to the body of the function. This must operate on // the original function, before the body region is inlined in the new // function to maintain the relation between block arguments and the // parent operation that assigns their semantics. OpBuilder::InsertionGuard guard(rewriter); // Rewrite workgroup memory attributions to addresses of global buffers. rewriter.setInsertionPointToStart(&gpuFuncOp.front()); unsigned numProperArguments = gpuFuncOp.getNumArguments(); auto i32Type = IntegerType::get(rewriter.getContext(), 32); Value zero = nullptr; if (!workgroupBuffers.empty()) zero = rewriter.create(loc, i32Type, rewriter.getI32IntegerAttr(0)); for (const auto &en : llvm::enumerate(workgroupBuffers)) { LLVM::GlobalOp global = en.value(); Value address = rewriter.create(loc, global); auto elementType = global.getType().cast().getElementType(); Value memory = rewriter.create( loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()), address, ArrayRef{zero, zero}); // Build a memref descriptor pointing to the buffer to plug with the // existing memref infrastructure. This may use more registers than // otherwise necessary given that memref sizes are fixed, but we can try // and canonicalize that away later. Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; auto type = attribution.getType().cast(); auto descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, memory); signatureConversion.remapInput(numProperArguments + en.index(), descr); } // Rewrite private memory attributions to alloca'ed buffers. unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions(); auto int64Ty = IntegerType::get(rewriter.getContext(), 64); for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { Value attribution = en.value(); auto type = attribution.getType().cast(); assert(type && type.hasStaticShape() && "unexpected type in attribution"); // Explicitly drop memory space when lowering private memory // attributions since NVVM models it as `alloca`s in the default // memory space and does not support `alloca`s with addrspace(5). auto ptrType = LLVM::LLVMPointerType::get( typeConverter->convertType(type.getElementType()) .template cast(), allocaAddrSpace); Value numElements = rewriter.create( gpuFuncOp.getLoc(), int64Ty, rewriter.getI64IntegerAttr(type.getNumElements())); Value allocated = rewriter.create( gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0); auto descr = MemRefDescriptor::fromStaticShape( rewriter, loc, *getTypeConverter(), type, allocated); signatureConversion.remapInput( numProperArguments + numWorkgroupAttributions + en.index(), descr); } } // Move the region to the new function, update the entry block signature. rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(), llvmFuncOp.end()); if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter, &signatureConversion))) return failure(); rewriter.eraseOp(gpuFuncOp); return success(); } static const char formatStringPrefix[] = "printfFormat_"; template static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, ConversionPatternRewriter &rewriter, StringRef name, LLVM::LLVMFunctionType type) { LLVM::LLVMFuncOp ret; if (!(ret = moduleOp.template lookupSymbol(name))) { ConversionPatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); ret = rewriter.create(loc, name, type, LLVM::Linkage::External); } return ret; } LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = gpuPrintfOp->getLoc(); mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type()); mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8); mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType()); mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type()); mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type()); // Note: this is the GPUModule op, not the ModuleOp that surrounds it // This ensures that global constants and declarations are placed within // the device code, not the host code auto moduleOp = gpuPrintfOp->getParentOfType(); auto ocklBegin = getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin", LLVM::LLVMFunctionType::get(llvmI64, {llvmI64})); LLVM::LLVMFuncOp ocklAppendArgs; if (!adaptor.args().empty()) { ocklAppendArgs = getOrDefineFunction( moduleOp, loc, rewriter, "__ockl_printf_append_args", LLVM::LLVMFunctionType::get( llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64, llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32})); } auto ocklAppendStringN = getOrDefineFunction( moduleOp, loc, rewriter, "__ockl_printf_append_string_n", LLVM::LLVMFunctionType::get( llvmI64, {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); /// Start the printf hostcall Value zeroI64 = rewriter.create( loc, llvmI64, rewriter.getI64IntegerAttr(0)); auto printfBeginCall = rewriter.create(loc, ocklBegin, zeroI64); Value printfDesc = printfBeginCall.getResult(0); // Create a global constant for the format string unsigned stringNumber = 0; SmallString<16> stringConstName; do { stringConstName.clear(); (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); } while (moduleOp.lookupSymbol(stringConstName)); llvm::SmallString<20> formatString(adaptor.format()); formatString.push_back('\0'); // Null terminate for C size_t formatStringSize = formatString.size_in_bytes(); auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize); LLVM::GlobalOp global; { ConversionPatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); global = rewriter.create( loc, globalType, /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, rewriter.getStringAttr(formatString)); } // Get a pointer to the format string's first element and pass it to printf() Value globalPtr = rewriter.create(loc, global); Value zero = rewriter.create( loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0)); Value stringStart = rewriter.create( loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero})); Value stringLen = rewriter.create( loc, llvmI64, rewriter.getI64IntegerAttr(formatStringSize)); Value oneI32 = rewriter.create( loc, llvmI32, rewriter.getI32IntegerAttr(1)); Value zeroI32 = rewriter.create( loc, llvmI32, rewriter.getI32IntegerAttr(0)); auto appendFormatCall = rewriter.create( loc, ocklAppendStringN, ValueRange{printfDesc, stringStart, stringLen, adaptor.args().empty() ? oneI32 : zeroI32}); printfDesc = appendFormatCall.getResult(0); // __ockl_printf_append_args takes 7 values per append call constexpr size_t argsPerAppend = 7; size_t nArgs = adaptor.args().size(); for (size_t group = 0; group < nArgs; group += argsPerAppend) { size_t bound = std::min(group + argsPerAppend, nArgs); size_t numArgsThisCall = bound - group; SmallVector arguments; arguments.push_back(printfDesc); arguments.push_back(rewriter.create( loc, llvmI32, rewriter.getI32IntegerAttr(numArgsThisCall))); for (size_t i = group; i < bound; ++i) { Value arg = adaptor.args()[i]; if (auto floatType = arg.getType().dyn_cast()) { if (!floatType.isF64()) arg = rewriter.create( loc, typeConverter->convertType(rewriter.getF64Type()), arg); arg = rewriter.create(loc, llvmI64, arg); } if (arg.getType().getIntOrFloatBitWidth() != 64) arg = rewriter.create(loc, llvmI64, arg); arguments.push_back(arg); } // Pad out to 7 arguments since the hostcall always needs 7 for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) { arguments.push_back(zeroI64); } auto isLast = (bound == nArgs) ? oneI32 : zeroI32; arguments.push_back(isLast); auto call = rewriter.create(loc, ocklAppendArgs, arguments); printfDesc = call.getResult(0); } rewriter.eraseOp(gpuPrintfOp); return success(); } LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Location loc = gpuPrintfOp->getLoc(); mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8, addressSpace); mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType()); // Note: this is the GPUModule op, not the ModuleOp that surrounds it // This ensures that global constants and declarations are placed within // the device code, not the host code auto moduleOp = gpuPrintfOp->getParentOfType(); auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr}, /*isVarArg=*/true); LLVM::LLVMFuncOp printfDecl = getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType); // Create a global constant for the format string unsigned stringNumber = 0; SmallString<16> stringConstName; do { stringConstName.clear(); (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); } while (moduleOp.lookupSymbol(stringConstName)); llvm::SmallString<20> formatString(adaptor.format()); formatString.push_back('\0'); // Null terminate for C auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes()); LLVM::GlobalOp global; { ConversionPatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); global = rewriter.create( loc, globalType, /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace); } // Get a pointer to the format string's first element Value globalPtr = rewriter.create(loc, global); Value zero = rewriter.create( loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0)); Value stringStart = rewriter.create( loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero})); // Construct arguments and function call auto argsRange = adaptor.args(); SmallVector printfArgs; printfArgs.reserve(argsRange.size() + 1); printfArgs.push_back(stringStart); printfArgs.append(argsRange.begin(), argsRange.end()); rewriter.create(loc, printfDecl, printfArgs); rewriter.eraseOp(gpuPrintfOp); return success(); }