1 //===- GPUOpsLowering.cpp - GPU FuncOp / ReturnOp lowering ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "GPUOpsLowering.h" 10 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 11 #include "mlir/Dialect/StandardOps/IR/Ops.h" 12 #include "mlir/IR/Builders.h" 13 #include "llvm/Support/FormatVariadic.h" 14 15 using namespace mlir; 16 17 LogicalResult 18 GPUFuncOpLowering::matchAndRewrite(gpu::GPUFuncOp gpuFuncOp, OpAdaptor adaptor, 19 ConversionPatternRewriter &rewriter) const { 20 Location loc = gpuFuncOp.getLoc(); 21 22 SmallVector<LLVM::GlobalOp, 3> workgroupBuffers; 23 workgroupBuffers.reserve(gpuFuncOp.getNumWorkgroupAttributions()); 24 for (const auto &en : llvm::enumerate(gpuFuncOp.getWorkgroupAttributions())) { 25 Value attribution = en.value(); 26 27 auto type = attribution.getType().dyn_cast<MemRefType>(); 28 assert(type && type.hasStaticShape() && "unexpected type in attribution"); 29 30 uint64_t numElements = type.getNumElements(); 31 32 auto elementType = 33 typeConverter->convertType(type.getElementType()).template cast<Type>(); 34 auto arrayType = LLVM::LLVMArrayType::get(elementType, numElements); 35 std::string name = std::string( 36 llvm::formatv("__wg_{0}_{1}", gpuFuncOp.getName(), en.index())); 37 auto globalOp = rewriter.create<LLVM::GlobalOp>( 38 gpuFuncOp.getLoc(), arrayType, /*isConstant=*/false, 39 LLVM::Linkage::Internal, name, /*value=*/Attribute(), 40 /*alignment=*/0, gpu::GPUDialect::getWorkgroupAddressSpace()); 41 workgroupBuffers.push_back(globalOp); 42 } 43 44 // Rewrite the original GPU function to an LLVM function. 45 auto funcType = typeConverter->convertType(gpuFuncOp.getType()) 46 .template cast<LLVM::LLVMPointerType>() 47 .getElementType(); 48 49 // Remap proper input types. 50 TypeConverter::SignatureConversion signatureConversion( 51 gpuFuncOp.front().getNumArguments()); 52 getTypeConverter()->convertFunctionSignature( 53 gpuFuncOp.getType(), /*isVariadic=*/false, signatureConversion); 54 55 // Create the new function operation. Only copy those attributes that are 56 // not specific to function modeling. 57 SmallVector<NamedAttribute, 4> attributes; 58 for (const auto &attr : gpuFuncOp->getAttrs()) { 59 if (attr.getName() == SymbolTable::getSymbolAttrName() || 60 attr.getName() == FunctionOpInterface::getTypeAttrName() || 61 attr.getName() == gpu::GPUFuncOp::getNumWorkgroupAttributionsAttrName()) 62 continue; 63 attributes.push_back(attr); 64 } 65 // Add a dialect specific kernel attribute in addition to GPU kernel 66 // attribute. The former is necessary for further translation while the 67 // latter is expected by gpu.launch_func. 68 if (gpuFuncOp.isKernel()) 69 attributes.emplace_back(kernelAttributeName, rewriter.getUnitAttr()); 70 auto llvmFuncOp = rewriter.create<LLVM::LLVMFuncOp>( 71 gpuFuncOp.getLoc(), gpuFuncOp.getName(), funcType, 72 LLVM::Linkage::External, /*dsoLocal*/ false, attributes); 73 74 { 75 // Insert operations that correspond to converted workgroup and private 76 // memory attributions to the body of the function. This must operate on 77 // the original function, before the body region is inlined in the new 78 // function to maintain the relation between block arguments and the 79 // parent operation that assigns their semantics. 80 OpBuilder::InsertionGuard guard(rewriter); 81 82 // Rewrite workgroup memory attributions to addresses of global buffers. 83 rewriter.setInsertionPointToStart(&gpuFuncOp.front()); 84 unsigned numProperArguments = gpuFuncOp.getNumArguments(); 85 auto i32Type = IntegerType::get(rewriter.getContext(), 32); 86 87 Value zero = nullptr; 88 if (!workgroupBuffers.empty()) 89 zero = rewriter.create<LLVM::ConstantOp>(loc, i32Type, 90 rewriter.getI32IntegerAttr(0)); 91 for (const auto &en : llvm::enumerate(workgroupBuffers)) { 92 LLVM::GlobalOp global = en.value(); 93 Value address = rewriter.create<LLVM::AddressOfOp>(loc, global); 94 auto elementType = 95 global.getType().cast<LLVM::LLVMArrayType>().getElementType(); 96 Value memory = rewriter.create<LLVM::GEPOp>( 97 loc, LLVM::LLVMPointerType::get(elementType, global.getAddrSpace()), 98 address, ArrayRef<Value>{zero, zero}); 99 100 // Build a memref descriptor pointing to the buffer to plug with the 101 // existing memref infrastructure. This may use more registers than 102 // otherwise necessary given that memref sizes are fixed, but we can try 103 // and canonicalize that away later. 104 Value attribution = gpuFuncOp.getWorkgroupAttributions()[en.index()]; 105 auto type = attribution.getType().cast<MemRefType>(); 106 auto descr = MemRefDescriptor::fromStaticShape( 107 rewriter, loc, *getTypeConverter(), type, memory); 108 signatureConversion.remapInput(numProperArguments + en.index(), descr); 109 } 110 111 // Rewrite private memory attributions to alloca'ed buffers. 112 unsigned numWorkgroupAttributions = gpuFuncOp.getNumWorkgroupAttributions(); 113 auto int64Ty = IntegerType::get(rewriter.getContext(), 64); 114 for (const auto &en : llvm::enumerate(gpuFuncOp.getPrivateAttributions())) { 115 Value attribution = en.value(); 116 auto type = attribution.getType().cast<MemRefType>(); 117 assert(type && type.hasStaticShape() && "unexpected type in attribution"); 118 119 // Explicitly drop memory space when lowering private memory 120 // attributions since NVVM models it as `alloca`s in the default 121 // memory space and does not support `alloca`s with addrspace(5). 122 auto ptrType = LLVM::LLVMPointerType::get( 123 typeConverter->convertType(type.getElementType()) 124 .template cast<Type>(), 125 allocaAddrSpace); 126 Value numElements = rewriter.create<LLVM::ConstantOp>( 127 gpuFuncOp.getLoc(), int64Ty, 128 rewriter.getI64IntegerAttr(type.getNumElements())); 129 Value allocated = rewriter.create<LLVM::AllocaOp>( 130 gpuFuncOp.getLoc(), ptrType, numElements, /*alignment=*/0); 131 auto descr = MemRefDescriptor::fromStaticShape( 132 rewriter, loc, *getTypeConverter(), type, allocated); 133 signatureConversion.remapInput( 134 numProperArguments + numWorkgroupAttributions + en.index(), descr); 135 } 136 } 137 138 // Move the region to the new function, update the entry block signature. 139 rewriter.inlineRegionBefore(gpuFuncOp.getBody(), llvmFuncOp.getBody(), 140 llvmFuncOp.end()); 141 if (failed(rewriter.convertRegionTypes(&llvmFuncOp.getBody(), *typeConverter, 142 &signatureConversion))) 143 return failure(); 144 145 rewriter.eraseOp(gpuFuncOp); 146 return success(); 147 } 148 149 static const char formatStringPrefix[] = "printfFormat_"; 150 151 template <typename T> 152 static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc, 153 ConversionPatternRewriter &rewriter, 154 StringRef name, 155 LLVM::LLVMFunctionType type) { 156 LLVM::LLVMFuncOp ret; 157 if (!(ret = moduleOp.template lookupSymbol<LLVM::LLVMFuncOp>(name))) { 158 ConversionPatternRewriter::InsertionGuard guard(rewriter); 159 rewriter.setInsertionPointToStart(moduleOp.getBody()); 160 ret = rewriter.create<LLVM::LLVMFuncOp>(loc, name, type, 161 LLVM::Linkage::External); 162 } 163 return ret; 164 } 165 166 LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite( 167 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, 168 ConversionPatternRewriter &rewriter) const { 169 Location loc = gpuPrintfOp->getLoc(); 170 171 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getI8Type()); 172 mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8); 173 mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType()); 174 mlir::Type llvmI32 = typeConverter->convertType(rewriter.getI32Type()); 175 mlir::Type llvmI64 = typeConverter->convertType(rewriter.getI64Type()); 176 // Note: this is the GPUModule op, not the ModuleOp that surrounds it 177 // This ensures that global constants and declarations are placed within 178 // the device code, not the host code 179 auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); 180 181 auto ocklBegin = 182 getOrDefineFunction(moduleOp, loc, rewriter, "__ockl_printf_begin", 183 LLVM::LLVMFunctionType::get(llvmI64, {llvmI64})); 184 LLVM::LLVMFuncOp ocklAppendArgs; 185 if (!adaptor.args().empty()) { 186 ocklAppendArgs = getOrDefineFunction( 187 moduleOp, loc, rewriter, "__ockl_printf_append_args", 188 LLVM::LLVMFunctionType::get( 189 llvmI64, {llvmI64, /*numArgs*/ llvmI32, llvmI64, llvmI64, llvmI64, 190 llvmI64, llvmI64, llvmI64, llvmI64, /*isLast*/ llvmI32})); 191 } 192 auto ocklAppendStringN = getOrDefineFunction( 193 moduleOp, loc, rewriter, "__ockl_printf_append_string_n", 194 LLVM::LLVMFunctionType::get( 195 llvmI64, 196 {llvmI64, i8Ptr, /*length (bytes)*/ llvmI64, /*isLast*/ llvmI32})); 197 198 /// Start the printf hostcall 199 Value zeroI64 = rewriter.create<LLVM::ConstantOp>( 200 loc, llvmI64, rewriter.getI64IntegerAttr(0)); 201 auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64); 202 Value printfDesc = printfBeginCall.getResult(0); 203 204 // Create a global constant for the format string 205 unsigned stringNumber = 0; 206 SmallString<16> stringConstName; 207 do { 208 stringConstName.clear(); 209 (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); 210 } while (moduleOp.lookupSymbol(stringConstName)); 211 212 llvm::SmallString<20> formatString(adaptor.format()); 213 formatString.push_back('\0'); // Null terminate for C 214 size_t formatStringSize = formatString.size_in_bytes(); 215 216 auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize); 217 LLVM::GlobalOp global; 218 { 219 ConversionPatternRewriter::InsertionGuard guard(rewriter); 220 rewriter.setInsertionPointToStart(moduleOp.getBody()); 221 global = rewriter.create<LLVM::GlobalOp>( 222 loc, globalType, 223 /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, 224 rewriter.getStringAttr(formatString)); 225 } 226 227 // Get a pointer to the format string's first element and pass it to printf() 228 Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global); 229 Value zero = rewriter.create<LLVM::ConstantOp>( 230 loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0)); 231 Value stringStart = rewriter.create<LLVM::GEPOp>( 232 loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero})); 233 Value stringLen = rewriter.create<LLVM::ConstantOp>( 234 loc, llvmI64, rewriter.getI64IntegerAttr(formatStringSize)); 235 236 Value oneI32 = rewriter.create<LLVM::ConstantOp>( 237 loc, llvmI32, rewriter.getI32IntegerAttr(1)); 238 Value zeroI32 = rewriter.create<LLVM::ConstantOp>( 239 loc, llvmI32, rewriter.getI32IntegerAttr(0)); 240 241 auto appendFormatCall = rewriter.create<LLVM::CallOp>( 242 loc, ocklAppendStringN, 243 ValueRange{printfDesc, stringStart, stringLen, 244 adaptor.args().empty() ? oneI32 : zeroI32}); 245 printfDesc = appendFormatCall.getResult(0); 246 247 // __ockl_printf_append_args takes 7 values per append call 248 constexpr size_t argsPerAppend = 7; 249 size_t nArgs = adaptor.args().size(); 250 for (size_t group = 0; group < nArgs; group += argsPerAppend) { 251 size_t bound = std::min(group + argsPerAppend, nArgs); 252 size_t numArgsThisCall = bound - group; 253 254 SmallVector<mlir::Value, 2 + argsPerAppend + 1> arguments; 255 arguments.push_back(printfDesc); 256 arguments.push_back(rewriter.create<LLVM::ConstantOp>( 257 loc, llvmI32, rewriter.getI32IntegerAttr(numArgsThisCall))); 258 for (size_t i = group; i < bound; ++i) { 259 Value arg = adaptor.args()[i]; 260 if (auto floatType = arg.getType().dyn_cast<FloatType>()) { 261 if (!floatType.isF64()) 262 arg = rewriter.create<LLVM::FPExtOp>( 263 loc, typeConverter->convertType(rewriter.getF64Type()), arg); 264 arg = rewriter.create<LLVM::BitcastOp>(loc, llvmI64, arg); 265 } 266 if (arg.getType().getIntOrFloatBitWidth() != 64) 267 arg = rewriter.create<LLVM::ZExtOp>(loc, llvmI64, arg); 268 269 arguments.push_back(arg); 270 } 271 // Pad out to 7 arguments since the hostcall always needs 7 272 for (size_t extra = numArgsThisCall; extra < argsPerAppend; ++extra) { 273 arguments.push_back(zeroI64); 274 } 275 276 auto isLast = (bound == nArgs) ? oneI32 : zeroI32; 277 arguments.push_back(isLast); 278 auto call = rewriter.create<LLVM::CallOp>(loc, ocklAppendArgs, arguments); 279 printfDesc = call.getResult(0); 280 } 281 rewriter.eraseOp(gpuPrintfOp); 282 return success(); 283 } 284 285 LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite( 286 gpu::PrintfOp gpuPrintfOp, gpu::PrintfOpAdaptor adaptor, 287 ConversionPatternRewriter &rewriter) const { 288 Location loc = gpuPrintfOp->getLoc(); 289 290 mlir::Type llvmI8 = typeConverter->convertType(rewriter.getIntegerType(8)); 291 mlir::Type i8Ptr = LLVM::LLVMPointerType::get(llvmI8, addressSpace); 292 mlir::Type llvmIndex = typeConverter->convertType(rewriter.getIndexType()); 293 294 // Note: this is the GPUModule op, not the ModuleOp that surrounds it 295 // This ensures that global constants and declarations are placed within 296 // the device code, not the host code 297 auto moduleOp = gpuPrintfOp->getParentOfType<gpu::GPUModuleOp>(); 298 299 auto printfType = LLVM::LLVMFunctionType::get(rewriter.getI32Type(), {i8Ptr}, 300 /*isVarArg=*/true); 301 LLVM::LLVMFuncOp printfDecl = 302 getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType); 303 304 // Create a global constant for the format string 305 unsigned stringNumber = 0; 306 SmallString<16> stringConstName; 307 do { 308 stringConstName.clear(); 309 (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName); 310 } while (moduleOp.lookupSymbol(stringConstName)); 311 312 llvm::SmallString<20> formatString(adaptor.format()); 313 formatString.push_back('\0'); // Null terminate for C 314 auto globalType = 315 LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes()); 316 LLVM::GlobalOp global; 317 { 318 ConversionPatternRewriter::InsertionGuard guard(rewriter); 319 rewriter.setInsertionPointToStart(moduleOp.getBody()); 320 global = rewriter.create<LLVM::GlobalOp>( 321 loc, globalType, 322 /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName, 323 rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace); 324 } 325 326 // Get a pointer to the format string's first element 327 Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global); 328 Value zero = rewriter.create<LLVM::ConstantOp>( 329 loc, llvmIndex, rewriter.getIntegerAttr(llvmIndex, 0)); 330 Value stringStart = rewriter.create<LLVM::GEPOp>( 331 loc, i8Ptr, globalPtr, mlir::ValueRange({zero, zero})); 332 333 // Construct arguments and function call 334 auto argsRange = adaptor.args(); 335 SmallVector<Value, 4> printfArgs; 336 printfArgs.reserve(argsRange.size() + 1); 337 printfArgs.push_back(stringStart); 338 printfArgs.append(argsRange.begin(), argsRange.end()); 339 340 rewriter.create<LLVM::CallOp>(loc, printfDecl, printfArgs); 341 rewriter.eraseOp(gpuPrintfOp); 342 return success(); 343 } 344