1 //===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to convert gpu.launch_func op into a sequence of 10 // GPU runtime calls. As most of GPU runtimes does not have a stable published 11 // ABI, this pass uses a slim runtime layer that builds on top of the public 12 // API from GPU runtime headers. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "mlir/Conversion/GPUCommon/GPUCommonPass.h" 17 18 #include "../PassDetail.h" 19 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h" 20 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" 21 #include "mlir/Conversion/LLVMCommon/Pattern.h" 22 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h" 23 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" 24 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" 25 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" 26 #include "mlir/Dialect/Async/IR/Async.h" 27 #include "mlir/Dialect/GPU/GPUDialect.h" 28 #include "mlir/Dialect/GPU/Passes.h" 29 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 30 #include "mlir/IR/Attributes.h" 31 #include "mlir/IR/Builders.h" 32 #include "mlir/IR/BuiltinOps.h" 33 #include "mlir/IR/BuiltinTypes.h" 34 35 #include "llvm/ADT/STLExtras.h" 36 #include "llvm/Support/Error.h" 37 #include "llvm/Support/FormatVariadic.h" 38 39 using namespace mlir; 40 41 static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst"; 42 43 namespace { 44 45 class GpuToLLVMConversionPass 46 : public GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> { 47 public: 48 GpuToLLVMConversionPass() = default; 49 50 GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other) 51 : GpuToLLVMConversionPassBase(other) {} 52 53 // Run the dialect converter on the module. 54 void runOnOperation() override; 55 56 private: 57 Option<std::string> gpuBinaryAnnotation{ 58 *this, "gpu-binary-annotation", 59 llvm::cl::desc("Annotation attribute string for GPU binary"), 60 llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())}; 61 }; 62 63 struct FunctionCallBuilder { 64 FunctionCallBuilder(StringRef functionName, Type returnType, 65 ArrayRef<Type> argumentTypes) 66 : functionName(functionName), 67 functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {} 68 LLVM::CallOp create(Location loc, OpBuilder &builder, 69 ArrayRef<Value> arguments) const; 70 71 StringRef functionName; 72 LLVM::LLVMFunctionType functionType; 73 }; 74 75 template <typename OpTy> 76 class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> { 77 public: 78 explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) 79 : ConvertOpToLLVMPattern<OpTy>(typeConverter) {} 80 81 protected: 82 Value getNumElements(ConversionPatternRewriter &rewriter, Location loc, 83 MemRefType type, MemRefDescriptor desc) const { 84 return type.hasStaticShape() 85 ? ConvertToLLVMPattern::createIndexConstant( 86 rewriter, loc, type.getNumElements()) 87 // For identity maps (verified by caller), the number of 88 // elements is stride[0] * size[0]. 89 : rewriter.create<LLVM::MulOp>(loc, 90 desc.stride(rewriter, loc, 0), 91 desc.size(rewriter, loc, 0)); 92 } 93 94 MLIRContext *context = &this->getTypeConverter()->getContext(); 95 96 Type llvmVoidType = LLVM::LLVMVoidType::get(context); 97 Type llvmPointerType = 98 LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); 99 Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType); 100 Type llvmInt8Type = IntegerType::get(context, 8); 101 Type llvmInt32Type = IntegerType::get(context, 32); 102 Type llvmInt64Type = IntegerType::get(context, 64); 103 Type llvmIntPtrType = IntegerType::get( 104 context, this->getTypeConverter()->getPointerBitwidth(0)); 105 106 FunctionCallBuilder moduleLoadCallBuilder = { 107 "mgpuModuleLoad", 108 llvmPointerType /* void *module */, 109 {llvmPointerType /* void *cubin */}}; 110 FunctionCallBuilder moduleUnloadCallBuilder = { 111 "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}}; 112 FunctionCallBuilder moduleGetFunctionCallBuilder = { 113 "mgpuModuleGetFunction", 114 llvmPointerType /* void *function */, 115 { 116 llvmPointerType, /* void *module */ 117 llvmPointerType /* char *name */ 118 }}; 119 FunctionCallBuilder launchKernelCallBuilder = { 120 "mgpuLaunchKernel", 121 llvmVoidType, 122 { 123 llvmPointerType, /* void* f */ 124 llvmIntPtrType, /* intptr_t gridXDim */ 125 llvmIntPtrType, /* intptr_t gridyDim */ 126 llvmIntPtrType, /* intptr_t gridZDim */ 127 llvmIntPtrType, /* intptr_t blockXDim */ 128 llvmIntPtrType, /* intptr_t blockYDim */ 129 llvmIntPtrType, /* intptr_t blockZDim */ 130 llvmInt32Type, /* unsigned int sharedMemBytes */ 131 llvmPointerType, /* void *hstream */ 132 llvmPointerPointerType, /* void **kernelParams */ 133 llvmPointerPointerType /* void **extra */ 134 }}; 135 FunctionCallBuilder streamCreateCallBuilder = { 136 "mgpuStreamCreate", llvmPointerType /* void *stream */, {}}; 137 FunctionCallBuilder streamDestroyCallBuilder = { 138 "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}}; 139 FunctionCallBuilder streamSynchronizeCallBuilder = { 140 "mgpuStreamSynchronize", 141 llvmVoidType, 142 {llvmPointerType /* void *stream */}}; 143 FunctionCallBuilder streamWaitEventCallBuilder = { 144 "mgpuStreamWaitEvent", 145 llvmVoidType, 146 {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}}; 147 FunctionCallBuilder eventCreateCallBuilder = { 148 "mgpuEventCreate", llvmPointerType /* void *event */, {}}; 149 FunctionCallBuilder eventDestroyCallBuilder = { 150 "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}}; 151 FunctionCallBuilder eventSynchronizeCallBuilder = { 152 "mgpuEventSynchronize", 153 llvmVoidType, 154 {llvmPointerType /* void *event */}}; 155 FunctionCallBuilder eventRecordCallBuilder = { 156 "mgpuEventRecord", 157 llvmVoidType, 158 {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}}; 159 FunctionCallBuilder hostRegisterCallBuilder = { 160 "mgpuMemHostRegisterMemRef", 161 llvmVoidType, 162 {llvmIntPtrType /* intptr_t rank */, 163 llvmPointerType /* void *memrefDesc */, 164 llvmIntPtrType /* intptr_t elementSizeBytes */}}; 165 FunctionCallBuilder allocCallBuilder = { 166 "mgpuMemAlloc", 167 llvmPointerType /* void * */, 168 {llvmIntPtrType /* intptr_t sizeBytes */, 169 llvmPointerType /* void *stream */}}; 170 FunctionCallBuilder deallocCallBuilder = { 171 "mgpuMemFree", 172 llvmVoidType, 173 {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}}; 174 FunctionCallBuilder memcpyCallBuilder = { 175 "mgpuMemcpy", 176 llvmVoidType, 177 {llvmPointerType /* void *dst */, llvmPointerType /* void *src */, 178 llvmIntPtrType /* intptr_t sizeBytes */, 179 llvmPointerType /* void *stream */}}; 180 FunctionCallBuilder memsetCallBuilder = { 181 "mgpuMemset32", 182 llvmVoidType, 183 {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */, 184 llvmIntPtrType /* intptr_t sizeBytes */, 185 llvmPointerType /* void *stream */}}; 186 }; 187 188 /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime 189 /// call. Currently it supports CUDA and ROCm (HIP). 190 class ConvertHostRegisterOpToGpuRuntimeCallPattern 191 : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> { 192 public: 193 ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) 194 : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {} 195 196 private: 197 LogicalResult 198 matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor, 199 ConversionPatternRewriter &rewriter) const override; 200 }; 201 202 /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime 203 /// call. Currently it supports CUDA and ROCm (HIP). 204 class ConvertAllocOpToGpuRuntimeCallPattern 205 : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> { 206 public: 207 ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) 208 : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {} 209 210 private: 211 LogicalResult 212 matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor, 213 ConversionPatternRewriter &rewriter) const override; 214 }; 215 216 /// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime 217 /// call. Currently it supports CUDA and ROCm (HIP). 218 class ConvertDeallocOpToGpuRuntimeCallPattern 219 : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> { 220 public: 221 ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) 222 : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {} 223 224 private: 225 LogicalResult 226 matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor, 227 ConversionPatternRewriter &rewriter) const override; 228 }; 229 230 class ConvertAsyncYieldToGpuRuntimeCallPattern 231 : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> { 232 public: 233 ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) 234 : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {} 235 236 private: 237 LogicalResult 238 matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor, 239 ConversionPatternRewriter &rewriter) const override; 240 }; 241 242 /// A rewrite pattern to convert gpu.wait operations into a GPU runtime 243 /// call. Currently it supports CUDA and ROCm (HIP). 244 class ConvertWaitOpToGpuRuntimeCallPattern 245 : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> { 246 public: 247 ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) 248 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {} 249 250 private: 251 LogicalResult 252 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor, 253 ConversionPatternRewriter &rewriter) const override; 254 }; 255 256 /// A rewrite pattern to convert gpu.wait async operations into a GPU runtime 257 /// call. Currently it supports CUDA and ROCm (HIP). 258 class ConvertWaitAsyncOpToGpuRuntimeCallPattern 259 : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> { 260 public: 261 ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) 262 : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {} 263 264 private: 265 LogicalResult 266 matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor, 267 ConversionPatternRewriter &rewriter) const override; 268 }; 269 270 /// A rewrite patter to convert gpu.launch_func operations into a sequence of 271 /// GPU runtime calls. Currently it supports CUDA and ROCm (HIP). 272 /// 273 /// In essence, a gpu.launch_func operations gets compiled into the following 274 /// sequence of runtime calls: 275 /// 276 /// * moduleLoad -- loads the module given the cubin / hsaco data 277 /// * moduleGetFunction -- gets a handle to the actual kernel function 278 /// * getStreamHelper -- initializes a new compute stream on GPU 279 /// * launchKernel -- launches the kernel on a stream 280 /// * streamSynchronize -- waits for operations on the stream to finish 281 /// 282 /// Intermediate data structures are allocated on the stack. 283 class ConvertLaunchFuncOpToGpuRuntimeCallPattern 284 : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> { 285 public: 286 ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter, 287 StringRef gpuBinaryAnnotation) 288 : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter), 289 gpuBinaryAnnotation(gpuBinaryAnnotation) {} 290 291 private: 292 Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, 293 OpBuilder &builder) const; 294 Value generateKernelNameConstant(StringRef moduleName, StringRef name, 295 Location loc, OpBuilder &builder) const; 296 297 LogicalResult 298 matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, 299 ConversionPatternRewriter &rewriter) const override; 300 301 llvm::SmallString<32> gpuBinaryAnnotation; 302 }; 303 304 class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> { 305 using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern; 306 307 LogicalResult matchAndRewrite(gpu::GPUModuleOp op, 308 PatternRewriter &rewriter) const override { 309 // GPU kernel modules are no longer necessary since we have a global 310 // constant with the CUBIN, or HSACO data. 311 rewriter.eraseOp(op); 312 return success(); 313 } 314 }; 315 316 /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime 317 /// call. Currently it supports CUDA and ROCm (HIP). 318 class ConvertMemcpyOpToGpuRuntimeCallPattern 319 : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> { 320 public: 321 ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) 322 : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {} 323 324 private: 325 LogicalResult 326 matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor, 327 ConversionPatternRewriter &rewriter) const override; 328 }; 329 330 /// A rewrite pattern to convert gpu.memset operations into a GPU runtime 331 /// call. Currently it supports CUDA and ROCm (HIP). 332 class ConvertMemsetOpToGpuRuntimeCallPattern 333 : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> { 334 public: 335 ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter) 336 : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {} 337 338 private: 339 LogicalResult 340 matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor, 341 ConversionPatternRewriter &rewriter) const override; 342 }; 343 } // namespace 344 345 void GpuToLLVMConversionPass::runOnOperation() { 346 LLVMTypeConverter converter(&getContext()); 347 RewritePatternSet patterns(&getContext()); 348 LLVMConversionTarget target(getContext()); 349 350 target.addIllegalDialect<gpu::GPUDialect>(); 351 352 populateVectorToLLVMConversionPatterns(converter, patterns); 353 populateMemRefToLLVMConversionPatterns(converter, patterns); 354 populateStdToLLVMConversionPatterns(converter, patterns); 355 populateAsyncStructuralTypeConversionsAndLegality(converter, patterns, 356 target); 357 populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation); 358 359 if (failed( 360 applyPartialConversion(getOperation(), target, std::move(patterns)))) 361 signalPassFailure(); 362 } 363 364 LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder, 365 ArrayRef<Value> arguments) const { 366 auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>(); 367 auto function = [&] { 368 if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName)) 369 return function; 370 return OpBuilder::atBlockEnd(module.getBody()) 371 .create<LLVM::LLVMFuncOp>(loc, functionName, functionType); 372 }(); 373 return builder.create<LLVM::CallOp>(loc, function, arguments); 374 } 375 376 // Returns whether all operands are of LLVM type. 377 static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands, 378 ConversionPatternRewriter &rewriter) { 379 if (!llvm::all_of(operands, [](Value value) { 380 return LLVM::isCompatibleType(value.getType()); 381 })) 382 return rewriter.notifyMatchFailure( 383 op, "Cannot convert if operands aren't of LLVM type."); 384 return success(); 385 } 386 387 static LogicalResult 388 isAsyncWithOneDependency(ConversionPatternRewriter &rewriter, 389 gpu::AsyncOpInterface op) { 390 if (op.getAsyncDependencies().size() != 1) 391 return rewriter.notifyMatchFailure( 392 op, "Can only convert with exactly one async dependency."); 393 394 if (!op.getAsyncToken()) 395 return rewriter.notifyMatchFailure(op, "Can convert only async version."); 396 397 return success(); 398 } 399 400 LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite( 401 gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor, 402 ConversionPatternRewriter &rewriter) const { 403 auto *op = hostRegisterOp.getOperation(); 404 if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter))) 405 return failure(); 406 407 Location loc = op->getLoc(); 408 409 auto memRefType = hostRegisterOp.value().getType(); 410 auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType(); 411 auto elementSize = getSizeInBytes(loc, elementType, rewriter); 412 413 auto arguments = getTypeConverter()->promoteOperands( 414 loc, op->getOperands(), adaptor.getOperands(), rewriter); 415 arguments.push_back(elementSize); 416 hostRegisterCallBuilder.create(loc, rewriter, arguments); 417 418 rewriter.eraseOp(op); 419 return success(); 420 } 421 422 LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite( 423 gpu::AllocOp allocOp, OpAdaptor adaptor, 424 ConversionPatternRewriter &rewriter) const { 425 MemRefType memRefType = allocOp.getType(); 426 427 if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) || 428 !isConvertibleAndHasIdentityMaps(memRefType) || 429 failed(isAsyncWithOneDependency(rewriter, allocOp))) 430 return failure(); 431 432 auto loc = allocOp.getLoc(); 433 434 // Get shape of the memref as values: static sizes are constant 435 // values and dynamic sizes are passed to 'alloc' as operands. 436 SmallVector<Value, 4> shape; 437 SmallVector<Value, 4> strides; 438 Value sizeBytes; 439 getMemRefDescriptorSizes(loc, memRefType, adaptor.dynamicSizes(), rewriter, 440 shape, strides, sizeBytes); 441 442 // Allocate the underlying buffer and store a pointer to it in the MemRef 443 // descriptor. 444 Type elementPtrType = this->getElementPtrType(memRefType); 445 auto stream = adaptor.asyncDependencies().front(); 446 Value allocatedPtr = 447 allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0); 448 allocatedPtr = 449 rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr); 450 451 // No alignment. 452 Value alignedPtr = allocatedPtr; 453 454 // Create the MemRef descriptor. 455 auto memRefDescriptor = this->createMemRefDescriptor( 456 loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter); 457 458 rewriter.replaceOp(allocOp, {memRefDescriptor, stream}); 459 460 return success(); 461 } 462 463 LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite( 464 gpu::DeallocOp deallocOp, OpAdaptor adaptor, 465 ConversionPatternRewriter &rewriter) const { 466 if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) || 467 failed(isAsyncWithOneDependency(rewriter, deallocOp))) 468 return failure(); 469 470 Location loc = deallocOp.getLoc(); 471 472 Value pointer = 473 MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc); 474 auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer); 475 Value stream = adaptor.asyncDependencies().front(); 476 deallocCallBuilder.create(loc, rewriter, {casted, stream}); 477 478 rewriter.replaceOp(deallocOp, {stream}); 479 return success(); 480 } 481 482 static bool isGpuAsyncTokenType(Value value) { 483 return value.getType().isa<gpu::AsyncTokenType>(); 484 } 485 486 // Converts !gpu.async.token operands of `async.yield` to runtime calls. The 487 // !gpu.async.token are lowered to stream within the async.execute region, but 488 // are passed as events between them. For each !gpu.async.token operand, we 489 // create an event and record it on the stream. 490 LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite( 491 async::YieldOp yieldOp, OpAdaptor adaptor, 492 ConversionPatternRewriter &rewriter) const { 493 if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType)) 494 return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand"); 495 496 Location loc = yieldOp.getLoc(); 497 SmallVector<Value, 4> newOperands(adaptor.getOperands()); 498 llvm::SmallDenseSet<Value> streams; 499 for (auto &operand : yieldOp->getOpOperands()) { 500 if (!isGpuAsyncTokenType(operand.get())) 501 continue; 502 auto idx = operand.getOperandNumber(); 503 auto stream = adaptor.getOperands()[idx]; 504 auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0); 505 eventRecordCallBuilder.create(loc, rewriter, {event, stream}); 506 newOperands[idx] = event; 507 streams.insert(stream); 508 } 509 for (auto stream : streams) 510 streamDestroyCallBuilder.create(loc, rewriter, {stream}); 511 512 rewriter.updateRootInPlace(yieldOp, 513 [&] { yieldOp->setOperands(newOperands); }); 514 return success(); 515 } 516 517 // Returns whether `value` is the result of an LLVM::CallOp to `functionName`. 518 static bool isDefinedByCallTo(Value value, StringRef functionName) { 519 assert(value.getType().isa<LLVM::LLVMPointerType>()); 520 if (auto defOp = value.getDefiningOp<LLVM::CallOp>()) 521 return defOp.callee()->equals(functionName); 522 return false; 523 } 524 525 // Converts `gpu.wait` to runtime calls. The converted op synchronizes the host 526 // with the stream/event operands. The operands are destroyed. That is, it 527 // assumes that it is not used afterwards or elsewhere. Otherwise we will get a 528 // runtime error. Eventually, we should guarantee this property. 529 LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite( 530 gpu::WaitOp waitOp, OpAdaptor adaptor, 531 ConversionPatternRewriter &rewriter) const { 532 if (waitOp.asyncToken()) 533 return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op."); 534 535 Location loc = waitOp.getLoc(); 536 537 for (auto operand : adaptor.getOperands()) { 538 if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) { 539 // The converted operand's definition created a stream. 540 streamSynchronizeCallBuilder.create(loc, rewriter, {operand}); 541 streamDestroyCallBuilder.create(loc, rewriter, {operand}); 542 } else { 543 // Otherwise the converted operand is an event. This assumes that we use 544 // events in control flow code as well. 545 eventSynchronizeCallBuilder.create(loc, rewriter, {operand}); 546 eventDestroyCallBuilder.create(loc, rewriter, {operand}); 547 } 548 } 549 550 rewriter.eraseOp(waitOp); 551 return success(); 552 } 553 554 // Converts `gpu.wait async` to runtime calls. The converted op creates a new 555 // stream that is synchronized with stream/event operands. The operands are 556 // destroyed. That is, it assumes that it is not used afterwards or elsewhere. 557 // Otherwise we will get a runtime error. Eventually, we should guarantee this 558 // property. 559 LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite( 560 gpu::WaitOp waitOp, OpAdaptor adaptor, 561 ConversionPatternRewriter &rewriter) const { 562 if (!waitOp.asyncToken()) 563 return rewriter.notifyMatchFailure(waitOp, "Can only convert async op."); 564 565 Location loc = waitOp.getLoc(); 566 567 auto insertionPoint = rewriter.saveInsertionPoint(); 568 SmallVector<Value, 1> events; 569 for (auto pair : 570 llvm::zip(waitOp.asyncDependencies(), adaptor.getOperands())) { 571 auto operand = std::get<1>(pair); 572 if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) { 573 // The converted operand's definition created a stream. Insert an event 574 // into the stream just after the last use of the original token operand. 575 auto *defOp = std::get<0>(pair).getDefiningOp(); 576 rewriter.setInsertionPointAfter(defOp); 577 auto event = 578 eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0); 579 eventRecordCallBuilder.create(loc, rewriter, {event, operand}); 580 events.push_back(event); 581 } else { 582 // Otherwise the converted operand is an event. This assumes that we use 583 // events in control flow code as well. 584 events.push_back(operand); 585 } 586 } 587 rewriter.restoreInsertionPoint(insertionPoint); 588 auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0); 589 for (auto event : events) 590 streamWaitEventCallBuilder.create(loc, rewriter, {stream, event}); 591 for (auto event : events) 592 eventDestroyCallBuilder.create(loc, rewriter, {event}); 593 rewriter.replaceOp(waitOp, {stream}); 594 595 return success(); 596 } 597 598 // Creates a struct containing all kernel parameters on the stack and returns 599 // an array of type-erased pointers to the fields of the struct. The array can 600 // then be passed to the CUDA / ROCm (HIP) kernel launch calls. 601 // The generated code is essentially as follows: 602 // 603 // %struct = alloca(sizeof(struct { Parameters... })) 604 // %array = alloca(NumParameters * sizeof(void *)) 605 // for (i : [0, NumParameters)) 606 // %fieldPtr = llvm.getelementptr %struct[0, i] 607 // llvm.store parameters[i], %fieldPtr 608 // %elementPtr = llvm.getelementptr %array[i] 609 // llvm.store %fieldPtr, %elementPtr 610 // return %array 611 Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray( 612 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const { 613 auto loc = launchOp.getLoc(); 614 auto numKernelOperands = launchOp.getNumKernelOperands(); 615 auto arguments = getTypeConverter()->promoteOperands( 616 loc, launchOp.getOperands().take_back(numKernelOperands), 617 adaptor.getOperands().take_back(numKernelOperands), builder); 618 auto numArguments = arguments.size(); 619 SmallVector<Type, 4> argumentTypes; 620 argumentTypes.reserve(numArguments); 621 for (auto argument : arguments) 622 argumentTypes.push_back(argument.getType()); 623 auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(), 624 argumentTypes); 625 auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, 626 builder.getI32IntegerAttr(1)); 627 auto structPtr = builder.create<LLVM::AllocaOp>( 628 loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0); 629 auto arraySize = builder.create<LLVM::ConstantOp>( 630 loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments)); 631 auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType, 632 arraySize, /*alignment=*/0); 633 auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type, 634 builder.getI32IntegerAttr(0)); 635 for (auto en : llvm::enumerate(arguments)) { 636 auto index = builder.create<LLVM::ConstantOp>( 637 loc, llvmInt32Type, builder.getI32IntegerAttr(en.index())); 638 auto fieldPtr = builder.create<LLVM::GEPOp>( 639 loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr, 640 ArrayRef<Value>{zero, index.getResult()}); 641 builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr); 642 auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType, 643 arrayPtr, index.getResult()); 644 auto casted = 645 builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr); 646 builder.create<LLVM::StoreOp>(loc, casted, elementPtr); 647 } 648 return arrayPtr; 649 } 650 651 // Generates an LLVM IR dialect global that contains the name of the given 652 // kernel function as a C string, and returns a pointer to its beginning. 653 // The code is essentially: 654 // 655 // llvm.global constant @kernel_name("function_name\00") 656 // func(...) { 657 // %0 = llvm.addressof @kernel_name 658 // %1 = llvm.constant (0 : index) 659 // %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*"> 660 // } 661 Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant( 662 StringRef moduleName, StringRef name, Location loc, 663 OpBuilder &builder) const { 664 // Make sure the trailing zero is included in the constant. 665 std::vector<char> kernelName(name.begin(), name.end()); 666 kernelName.push_back('\0'); 667 668 std::string globalName = 669 std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name)); 670 return LLVM::createGlobalString( 671 loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()), 672 LLVM::Linkage::Internal); 673 } 674 675 // Emits LLVM IR to launch a kernel function. Expects the module that contains 676 // the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a 677 // hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR. 678 // 679 // %0 = call %binarygetter 680 // %1 = call %moduleLoad(%0) 681 // %2 = <see generateKernelNameConstant> 682 // %3 = call %moduleGetFunction(%1, %2) 683 // %4 = call %streamCreate() 684 // %5 = <see generateParamsArray> 685 // call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr) 686 // call %streamSynchronize(%4) 687 // call %streamDestroy(%4) 688 // call %moduleUnload(%1) 689 // 690 // If the op is async, the stream corresponds to the (single) async dependency 691 // as well as the async token the op produces. 692 LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite( 693 gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, 694 ConversionPatternRewriter &rewriter) const { 695 if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter))) 696 return failure(); 697 698 if (launchOp.asyncDependencies().size() > 1) 699 return rewriter.notifyMatchFailure( 700 launchOp, "Cannot convert with more than one async dependency."); 701 702 // Fail when the synchronous version of the op has async dependencies. The 703 // lowering destroys the stream, and we do not want to check that there is no 704 // use of the stream after this op. 705 if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty()) 706 return rewriter.notifyMatchFailure( 707 launchOp, "Cannot convert non-async op with async dependencies."); 708 709 Location loc = launchOp.getLoc(); 710 711 // Create an LLVM global with CUBIN extracted from the kernel annotation and 712 // obtain a pointer to the first byte in it. 713 auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>( 714 launchOp, launchOp.getKernelModuleName()); 715 assert(kernelModule && "expected a kernel module"); 716 717 auto binaryAttr = 718 kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation); 719 if (!binaryAttr) { 720 kernelModule.emitOpError() 721 << "missing " << gpuBinaryAnnotation << " attribute"; 722 return failure(); 723 } 724 725 SmallString<128> nameBuffer(kernelModule.getName()); 726 nameBuffer.append(kGpuBinaryStorageSuffix); 727 Value data = 728 LLVM::createGlobalString(loc, rewriter, nameBuffer.str(), 729 binaryAttr.getValue(), LLVM::Linkage::Internal); 730 731 auto module = moduleLoadCallBuilder.create(loc, rewriter, data); 732 // Get the function from the module. The name corresponds to the name of 733 // the kernel function. 734 auto kernelName = generateKernelNameConstant( 735 launchOp.getKernelModuleName().getValue(), 736 launchOp.getKernelName().getValue(), loc, rewriter); 737 auto function = moduleGetFunctionCallBuilder.create( 738 loc, rewriter, {module.getResult(0), kernelName}); 739 auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type, 740 rewriter.getI32IntegerAttr(0)); 741 Value stream = 742 adaptor.asyncDependencies().empty() 743 ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0) 744 : adaptor.asyncDependencies().front(); 745 // Create array of pointers to kernel arguments. 746 auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter); 747 auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType); 748 launchKernelCallBuilder.create(loc, rewriter, 749 {function.getResult(0), adaptor.gridSizeX(), 750 adaptor.gridSizeY(), adaptor.gridSizeZ(), 751 adaptor.blockSizeX(), adaptor.blockSizeY(), 752 adaptor.blockSizeZ(), 753 /*sharedMemBytes=*/zero, stream, kernelParams, 754 /*extra=*/nullpointer}); 755 756 if (launchOp.asyncToken()) { 757 // Async launch: make dependent ops use the same stream. 758 rewriter.replaceOp(launchOp, {stream}); 759 } else { 760 // Synchronize with host and destroy stream. This must be the stream created 761 // above (with no other uses) because we check that the synchronous version 762 // does not have any async dependencies. 763 streamSynchronizeCallBuilder.create(loc, rewriter, stream); 764 streamDestroyCallBuilder.create(loc, rewriter, stream); 765 rewriter.eraseOp(launchOp); 766 } 767 moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0)); 768 769 return success(); 770 } 771 772 LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite( 773 gpu::MemcpyOp memcpyOp, OpAdaptor adaptor, 774 ConversionPatternRewriter &rewriter) const { 775 auto memRefType = memcpyOp.src().getType().cast<MemRefType>(); 776 777 if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) || 778 !isConvertibleAndHasIdentityMaps(memRefType) || 779 failed(isAsyncWithOneDependency(rewriter, memcpyOp))) 780 return failure(); 781 782 auto loc = memcpyOp.getLoc(); 783 784 MemRefDescriptor srcDesc(adaptor.src()); 785 Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc); 786 787 Type elementPtrType = getElementPtrType(memRefType); 788 Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType); 789 Value gepPtr = rewriter.create<LLVM::GEPOp>( 790 loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements}); 791 auto sizeBytes = 792 rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr); 793 794 auto src = rewriter.create<LLVM::BitcastOp>( 795 loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc)); 796 auto dst = rewriter.create<LLVM::BitcastOp>( 797 loc, llvmPointerType, 798 MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc)); 799 800 auto stream = adaptor.asyncDependencies().front(); 801 memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream}); 802 803 rewriter.replaceOp(memcpyOp, {stream}); 804 805 return success(); 806 } 807 808 LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite( 809 gpu::MemsetOp memsetOp, OpAdaptor adaptor, 810 ConversionPatternRewriter &rewriter) const { 811 auto memRefType = memsetOp.dst().getType().cast<MemRefType>(); 812 813 if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) || 814 !isConvertibleAndHasIdentityMaps(memRefType) || 815 failed(isAsyncWithOneDependency(rewriter, memsetOp))) 816 return failure(); 817 818 auto loc = memsetOp.getLoc(); 819 820 Type valueType = adaptor.value().getType(); 821 if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) { 822 return rewriter.notifyMatchFailure(memsetOp, 823 "value must be a 32 bit scalar"); 824 } 825 826 MemRefDescriptor dstDesc(adaptor.dst()); 827 Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc); 828 829 auto value = 830 rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.value()); 831 auto dst = rewriter.create<LLVM::BitcastOp>( 832 loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc)); 833 834 auto stream = adaptor.asyncDependencies().front(); 835 memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream}); 836 837 rewriter.replaceOp(memsetOp, {stream}); 838 return success(); 839 } 840 841 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> 842 mlir::createGpuToLLVMConversionPass() { 843 return std::make_unique<GpuToLLVMConversionPass>(); 844 } 845 846 void mlir::populateGpuToLLVMConversionPatterns( 847 LLVMTypeConverter &converter, OwningRewritePatternList &patterns, 848 StringRef gpuBinaryAnnotation) { 849 converter.addConversion( 850 [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type { 851 return LLVM::LLVMPointerType::get(IntegerType::get(context, 8)); 852 }); 853 patterns.add<ConvertAllocOpToGpuRuntimeCallPattern, 854 ConvertDeallocOpToGpuRuntimeCallPattern, 855 ConvertHostRegisterOpToGpuRuntimeCallPattern, 856 ConvertMemcpyOpToGpuRuntimeCallPattern, 857 ConvertMemsetOpToGpuRuntimeCallPattern, 858 ConvertWaitAsyncOpToGpuRuntimeCallPattern, 859 ConvertWaitOpToGpuRuntimeCallPattern, 860 ConvertAsyncYieldToGpuRuntimeCallPattern>(converter); 861 patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter, 862 gpuBinaryAnnotation); 863 patterns.add<EraseGpuModuleOpPattern>(&converter.getContext()); 864 } 865