1 //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "GPUOpsLowering.h" 10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 11 #include "mlir/IR/Builders.h" 12 #include "llvm/Support/FormatVariadic.h" 13 14 using namespace mlir; 15 16 LogicalResult 17 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, 18 ConversionPatternRewriter &rewriter) const { 19 Location loc = gpuFuncOp.getLoc(); 20 21 SmallVector<LLVM::GlobalOp, 3> workgroupBuffers; 22 workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions()); 23 for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { 24 Value attribution = en.value(); 25 26 auto type = attribution.getType().dyn_cast<MemRefType>(); 27 assert(type && type.hasStaticShape() && "unexpected type in attribution"); 28 29 uint64_t numElements = type.getNumElements(); 30 31 auto elementType = 32 typeConverter->convertType(type.getElementType()).template cast<Type>(); 33 auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); 34 std::string name = std::string( 35 llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index())); 36 auto globalOp = rewriter.create<LLVM::GlobalOp>( 37 gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, 38 LLVM::Linkage::Internal, name, /*value=*/Attribute(), 39 /*alignment=*/0, gpu::GPUDialect::getWorkgroupAddressSpace()); 40 workgroupBuffers.push_back(globalOp); 41 } 42 43 // Rewrite the original GPU function to an LLVM function. 44 auto funcType = typeConverter->convertType(gpuFuncOp.getFunctionType()) 45 .template cast<LLVM::LLVMPointerType>() 46 .getElementType(); 47 48 // Remap proper input types. 49 TypeConverter::SignatureConversion signatureConversion( 50 gpuFuncOp.front().getNumArguments()); 51 getTypeConverter()->convertFunctionSignature( 52 gpuFuncOp.getFunctionType(), /*isVariadic=*/false, signatureConversion); 53 54 // Create the new function operation. Only copy those attributes that are 55 // not specific to function modeling. 56 SmallVector<NamedAttribute, 4> attributes; 57 for (const auto &attr : gpuFuncOp->getAttrs()) { 58 if (attr.getName() == SymbolTable::getSymbolAttrName() || 59 attr.getName() == FunctionOpInterface::getTypeAttrName() || 60 attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()) 61 continue; 62 attributes.push_back(attr); 63 } 64 // Add a dialect specific kernel attribute in addition to GPU kernel 65 // attribute. The former is necessary for further translation while the 66 // latter is expected by gpu.launch_func. 67 if (gpuFuncOp.isKernel()) 68 attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr()); 69 auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>( 70 gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, 71 LLVM::Linkage::External, /*dsoLocal*/ false, attributes); 72 73 { 74 // Insert operations that correspond to converted workgroup and private 75 // memory attributions to the body of the function. This must operate on 76 // the original function, before the body region is inlined in the new 77 // function to maintain the relation between block arguments and the 78 // parent operation that assigns their semantics. 79 OpBuilder::InsertionGuard guard(rewriter); 80 81 // Rewrite workgroup memory attributions to addresses of global buffers. 82 rewriter.setInsertionPointToStart(&gpuFuncOp.front()); 83 unsigned numProperArguments = gpuFuncOp.getNumArguments(); 84 auto i32Type = IntegerType::get(rewriter.getContext(), 32); 85 86 Value zero = nullptr; 87 if (!workgroupBuffers.empty()) 88 zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 89 rewriter.getI32IntegerAttr(0)); 90 for (const auto &en : llvm::enumerate(workgroupBuffers)) { 91 LLVM::GlobalOp global = en.value(); 92 Value address = rewriter.create<LLVM::AddressOfOp>(loc, global); 93 auto elementType = 94 global.getType().cast<LLVM::LLVMArrayType>().getElementType(); 95 Value memory = rewriter.create<LLVM::GEPOp>( 96 loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()), 97 address, ArrayRef<Value>{zero, zero}); 98 99 // Build a memref descriptor pointing to the buffer to plug with the 100 // existing memref infrastructure. This may use more registers than 101 // otherwise necessary given that memref sizes are fixed, but we can try 102 // and canonicalize that away later. 103 Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; 104 auto type = attribution.getType().cast<MemRefType>(); 105 auto descr = MemRefDescriptor::fromStaticShape( 106 rewriter, loc, *getTypeConverter(), type, memory); 107 signatureConversion.remapInput(numProperArguments + en.index(), descr); 108 } 109 110 // Rewrite private memory attributions to alloca'ed buffers. 111 unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions(); 112 auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 113 for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { 114 Value attribution = en.value(); 115 auto type = attribution.getType().cast<MemRefType>(); 116 assert(type && type.hasStaticShape() && "unexpected type in attribution"); 117 118 // Explicitly drop memory space when lowering private memory 119 // attributions since NVVM models it as `alloca`s in the default 120 // memory space and does not support `alloca`s with addrspace(5). 121 auto ptrType = LLVM::LLVMPointerType::get( 122 typeConverter->convertType(type.getElementType()) 123 .template cast<Type>(), 124 allocaAddrSpace); 125 Value numElements = rewriter.create<LLVM::ConstantOp>( 126 gpuFuncOp.getLoc(), int64Ty, 127 rewriter.getI64IntegerAttr(type.getNumElements())); 128 Value allocated = rewriter.create<LLVM::AllocaOp>( 129 gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0); 130 auto descr = MemRefDescriptor::fromStaticShape( 131 rewriter, loc, *getTypeConverter(), type, allocated); 132 signatureConversion.remapInput( 133 numProperArguments + numWorkgroupAttributions + en.index(), descr); 134 } 135 } 136 137 // Move the region to the new function, update the entry block signature. 138 rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(), 139 llvmFuncOp.end()); 140 if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter, 141 &signatureConversion))) 142 return failure(); 143 144 rewriter.eraseOp(gpuFuncOp); 145 return success(); 146 } 147 148 static const char formatStringPrefix[] = "printfFormat_"; 149 150 template <typename T> 151 static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, 152 ConversionPatternRewriter &rewriter, 153 StringRef name, 154 LLVM::LLVMFunctionType type) { 155 LLVM::LLVMFuncOp ret; 156 if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { 157 ConversionPatternRewriter::InsertionGuard guard(rewriter); 158 rewriter.setInsertionPointToStart(moduleOp.getBody()); 159 ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type, 160 LLVM::Linkage::External); 161 } 162 return ret; 163 } 164 165 LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( 166 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, 167 ConversionPatternRewriter &rewriter) const { 168 Location loc = gpuPrintfOp->getLoc(); 169 170 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type()); 171 mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8); 172 mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType()); 173 mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type()); 174 mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type()); 175 // Note: this is the GPUModule op, not the ModuleOp that surrounds it 176 // This ensures that global constants and declarations are placed within 177 // the device code, not the host code 178 auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); 179 180 auto ocklBegin = 181 getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin", 182 LLVM::LLVMFunctionType::get(llvmI64, {llvmI64})); 183 LLVM::LLVMFuncOp ocklAppendArgs; 184 if (!adaptor.args().empty()) { 185 ocklAppendArgs = getOrDefineFunction( 186 moduleOp, loc, rewriter, "__ockl_printf_append_args", 187 LLVM::LLVMFunctionType::get( 188 llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64, 189 llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32})); 190 } 191 auto ocklAppendStringN = getOrDefineFunction( 192 moduleOp, loc, rewriter, "__ockl_printf_append_string_n", 193 LLVM::LLVMFunctionType::get( 194 llvmI64, 195 {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); 196 197 /// Start the printf hostcall 198 Value zeroI64 = rewriter.create<LLVM::ConstantOp>( 199 loc, llvmI64, rewriter.getI64IntegerAttr(0)); 200 auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64); 201 Value printfDesc = printfBeginCall.getResult(0); 202 203 // Create a global constant for the format string 204 unsigned stringNumber = 0; 205 SmallString<16> stringConstName; 206 do { 207 stringConstName.clear(); 208 (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); 209 } while (moduleOp.lookupSymbol(stringConstName)); 210 211 llvm::SmallString<20> formatString(adaptor.format()); 212 formatString.push_back('\0'); // Null terminate for C 213 size_t formatStringSize = formatString.size_in_bytes(); 214 215 auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize); 216 LLVM::GlobalOp global; 217 { 218 ConversionPatternRewriter::InsertionGuard guard(rewriter); 219 rewriter.setInsertionPointToStart(moduleOp.getBody()); 220 global = rewriter.create<LLVM::GlobalOp>( 221 loc, globalType, 222 /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, 223 rewriter.getStringAttr(formatString)); 224 } 225 226 // Get a pointer to the format string's first element and pass it to printf() 227 Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global); 228 Value zero = rewriter.create<LLVM::ConstantOp>( 229 loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0)); 230 Value stringStart = rewriter.create<LLVM::GEPOp>( 231 loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero})); 232 Value stringLen = rewriter.create<LLVM::ConstantOp>( 233 loc, llvmI64, rewriter.getI64IntegerAttr(formatStringSize)); 234 235 Value oneI32 = rewriter.create<LLVM::ConstantOp>( 236 loc, llvmI32, rewriter.getI32IntegerAttr(1)); 237 Value zeroI32 = rewriter.create<LLVM::ConstantOp>( 238 loc, llvmI32, rewriter.getI32IntegerAttr(0)); 239 240 auto appendFormatCall = rewriter.create<LLVM::CallOp>( 241 loc, ocklAppendStringN, 242 ValueRange{printfDesc, stringStart, stringLen, 243 adaptor.args().empty() ? oneI32 : zeroI32}); 244 printfDesc = appendFormatCall.getResult(0); 245 246 // __ockl_printf_append_args takes 7 values per append call 247 constexpr size_t argsPerAppend = 7; 248 size_t nArgs = adaptor.args().size(); 249 for (size_t group = 0; group < nArgs; group += argsPerAppend) { 250 size_t bound = std::min(group + argsPerAppend, nArgs); 251 size_t numArgsThisCall = bound - group; 252 253 SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments; 254 arguments.push_back(printfDesc); 255 arguments.push_back(rewriter.create<LLVM::ConstantOp>( 256 loc, llvmI32, rewriter.getI32IntegerAttr(numArgsThisCall))); 257 for (size_t i = group; i < bound; ++i) { 258 Value arg = adaptor.args()[i]; 259 if (auto floatType = arg.getType().dyn_cast<FloatType>()) { 260 if (!floatType.isF64()) 261 arg = rewriter.create<LLVM::FPExtOp>( 262 loc, typeConverter->convertType(rewriter.getF64Type()), arg); 263 arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg); 264 } 265 if (arg.getType().getIntOrFloatBitWidth() != 64) 266 arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg); 267 268 arguments.push_back(arg); 269 } 270 // Pad out to 7 arguments since the hostcall always needs 7 271 for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) { 272 arguments.push_back(zeroI64); 273 } 274 275 auto isLast = (bound == nArgs) ? oneI32 : zeroI32; 276 arguments.push_back(isLast); 277 auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments); 278 printfDesc = call.getResult(0); 279 } 280 rewriter.eraseOp(gpuPrintfOp); 281 return success(); 282 } 283 284 LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( 285 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, 286 ConversionPatternRewriter &rewriter) const { 287 Location loc = gpuPrintfOp->getLoc(); 288 289 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); 290 mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8, addressSpace); 291 mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType()); 292 293 // Note: this is the GPUModule op, not the ModuleOp that surrounds it 294 // This ensures that global constants and declarations are placed within 295 // the device code, not the host code 296 auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); 297 298 auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr}, 299 /*isVarArg=*/true); 300 LLVM::LLVMFuncOp printfDecl = 301 getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType); 302 303 // Create a global constant for the format string 304 unsigned stringNumber = 0; 305 SmallString<16> stringConstName; 306 do { 307 stringConstName.clear(); 308 (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); 309 } while (moduleOp.lookupSymbol(stringConstName)); 310 311 llvm::SmallString<20> formatString(adaptor.format()); 312 formatString.push_back('\0'); // Null terminate for C 313 auto globalType = 314 LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes()); 315 LLVM::GlobalOp global; 316 { 317 ConversionPatternRewriter::InsertionGuard guard(rewriter); 318 rewriter.setInsertionPointToStart(moduleOp.getBody()); 319 global = rewriter.create<LLVM::GlobalOp>( 320 loc, globalType, 321 /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, 322 rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace); 323 } 324 325 // Get a pointer to the format string's first element 326 Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global); 327 Value zero = rewriter.create<LLVM::ConstantOp>( 328 loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0)); 329 Value stringStart = rewriter.create<LLVM::GEPOp>( 330 loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero})); 331 332 // Construct arguments and function call 333 auto argsRange = adaptor.args(); 334 SmallVector<Value, 4> printfArgs; 335 printfArgs.reserve(argsRange.size() + 1); 336 printfArgs.push_back(stringStart); 337 printfArgs.append(argsRange.begin(), argsRange.end()); 338 339 rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs); 340 rewriter.eraseOp(gpuPrintfOp); 341 return success(); 342 } 343