//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the gpu.launch_func op into a
// sequence of GPU runtime calls. As most GPU runtimes do not have a stable
// published ABI, this pass uses a slim runtime layer that builds on top of
// the public API from GPU runtime headers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";

namespace {

class GpuToLLVMConversionPass
    : public GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  GpuToLLVMConversionPass() = default;

  GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
      : GpuToLLVMConversionPassBase(other) {}

  // Run the dialect converter on the module.
  void runOnOperation() override;

private:
  Option<std::string> gpuBinaryAnnotation{
      *this, "gpu-binary-annotation",
      llvm::cl::desc("Annotation attribute string for GPU binary"),
      llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
};

struct FunctionCallBuilder {
  FunctionCallBuilder(StringRef functionName, Type returnType,
                      ArrayRef<Type> argumentTypes)
      : functionName(functionName),
        functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
  LLVM::CallOp create(Location loc, OpBuilder &builder,
                      ArrayRef<Value> arguments) const;

  StringRef functionName;
  LLVM::LLVMFunctionType functionType;
};
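// As an illustrative sketch (the runtime function name "mgpuFoo" below is
// hypothetical), a builder such as
//   FunctionCallBuilder fooBuilder = {
//       "mgpuFoo", llvmVoidType, {llvmPointerType}};
// emits `llvm.call @mgpuFoo(%ptr) : (!llvm.ptr<i8>) -> ()` via
//   fooBuilder.create(loc, rewriter, {ptr});
// and, on first use, declares `llvm.func @mgpuFoo(!llvm.ptr<i8>)` at the end
// of the enclosing module if no such symbol exists yet.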
template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    return type.hasStaticShape()
               ? ConvertToLLVMPattern::createIndexConstant(
                     rewriter, loc, type.getNumElements())
               // For identity maps (verified by caller), the number of
               // elements is stride[0] * size[0].
               : rewriter.create<LLVM::MulOp>(loc,
                                              desc.stride(rewriter, loc, 0),
                                              desc.size(rewriter, loc, 0));
  }

  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  Type llvmPointerType =
      LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
  Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

  FunctionCallBuilder moduleLoadCallBuilder = {
      "mgpuModuleLoad",
      llvmPointerType /* void *module */,
      {llvmPointerType /* void *cubin */}};
  FunctionCallBuilder moduleUnloadCallBuilder = {
      "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
  FunctionCallBuilder moduleGetFunctionCallBuilder = {
      "mgpuModuleGetFunction",
      llvmPointerType /* void *function */,
      {
          llvmPointerType, /* void *module */
          llvmPointerType  /* char *name */
      }};
  FunctionCallBuilder launchKernelCallBuilder = {
      "mgpuLaunchKernel",
      llvmVoidType,
      {
          llvmPointerType,        /* void* f */
          llvmIntPtrType,         /* intptr_t gridXDim */
          llvmIntPtrType,         /* intptr_t gridYDim */
          llvmIntPtrType,         /* intptr_t gridZDim */
          llvmIntPtrType,         /* intptr_t blockXDim */
          llvmIntPtrType,         /* intptr_t blockYDim */
          llvmIntPtrType,         /* intptr_t blockZDim */
          llvmInt32Type,          /* unsigned int sharedMemBytes */
          llvmPointerType,        /* void *hstream */
          llvmPointerPointerType, /* void **kernelParams */
          llvmPointerPointerType  /* void **extra */
      }};
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memsetCallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder setDefaultDeviceCallBuilder = {
      "mgpuSetDefaultDevice",
      llvmVoidType,
      {llvmInt32Type /* uint32_t devIndex */}};
};
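// Illustrative sketch of the overall lowering model: `!gpu.async.token`
// values become `!llvm.ptr<i8>` stream handles, so an async chain such as
//   %t0 = gpu.wait async
//   %memref, %t1 = gpu.alloc async [%t0] () : memref<16xf32>
//   gpu.wait [%t1]
// lowers (roughly) to
//   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr<i8>
//   %ptr = llvm.call @mgpuMemAlloc(%size, %stream) ...
//   llvm.call @mgpuStreamSynchronize(%stream) : (!llvm.ptr<i8>) -> ()
//   llvm.call @mgpuStreamDestroy(%stream) : (!llvm.ptr<i8>) -> ()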
/// A rewrite pattern to convert gpu.host_register operations into a GPU
/// runtime call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.launch_func operations into a sequence of
/// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
///
/// In essence, a gpu.launch_func operation gets compiled into the following
/// sequence of runtime calls:
///
/// * moduleLoad -- loads the module given the cubin / hsaco data
/// * moduleGetFunction -- gets a handle to the actual kernel function
/// * streamCreate -- creates a new compute stream on the GPU
/// * launchKernel -- launches the kernel on a stream
/// * streamSynchronize -- waits for operations on the stream to finish
///
/// Intermediate data structures are allocated on the stack.
class ConvertLaunchFuncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
                                             StringRef gpuBinaryAnnotation)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        gpuBinaryAnnotation(gpuBinaryAnnotation) {}

private:
  Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                            OpBuilder &builder) const;
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

  llvm::SmallString<32> gpuBinaryAnnotation;
};
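// For example (illustrative IR, symbol names hypothetical), a launch such as
//   gpu.launch_func @kernel_module::@my_kernel
//       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
//       args(%arg0 : f32)
// becomes the call sequence above, with the cubin / hsaco data taken from the
// `gpuBinaryAnnotation` attribute of @kernel_module and embedded as a global
// string constant in the host module.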
class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
  using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
                                PatternRewriter &rewriter) const override {
    // GPU kernel modules are no longer necessary since we have a global
    // constant with the CUBIN or HSACO data.
    rewriter.eraseOp(op);
    return success();
  }
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
/// Currently it supports CUDA and ROCm (HIP).
class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
public:
  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
      LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
            typeConverter) {}

  LogicalResult
  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
  LLVMTypeConverter converter(&getContext());
  RewritePatternSet patterns(&getContext());
  LLVMConversionTarget target(getContext());

  target.addIllegalDialect<gpu::GPUDialect>();

  mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
  mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateMemRefToLLVMConversionPatterns(converter, patterns);
  populateFuncToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);
  populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}
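// Usage sketch (assuming the standard mlir-opt registration of this pass):
// run it on the host module after kernel outlining and binary annotation,
// e.g.
//   mlir-opt --gpu-kernel-outlining ... --gpu-to-llvm input.mlir
// or, programmatically,
//   pm.addPass(createGpuToLLVMConversionPass());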
LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(loc, function, arguments);
}

// Returns success if all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.value().getType();
  auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}
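// For illustration, the pattern above lowers
//   gpu.host_register %buf : memref<*xf32>
// to a call
//   llvm.call @mgpuMemHostRegisterMemRef(%rank, %descPtr, %elemSizeBytes)
// where the unranked descriptor is promoted to a stack allocation by
// promoteOperands and %elemSizeBytes is sizeof(f32) computed in LLVM IR.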
LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  auto loc = allocOp.getLoc();

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.dynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  Type elementPtrType = this->getElementPtrType(memRefType);
  auto stream = adaptor.asyncDependencies().front();
  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
  allocatedPtr =
      rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});

  return success();
}

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  Value pointer =
      MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
  auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
  Value stream = adaptor.asyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {casted, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return value.getType().isa<gpu::AsyncTokenType>();
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token values are lowered to streams within an async.execute
// region, but are passed as events between regions. For each !gpu.async.token
// operand, we create an event and record it on the stream.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(adaptor.getOperands());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = adaptor.getOperands()[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.updateRootInPlace(yieldOp,
                             [&] { yieldOp->setOperands(newOperands); });
  return success();
}
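// Illustrative sketch of the conversion above: with %stream being the lowered
// form of a !gpu.async.token,
//   async.yield %token : !gpu.async.token
// becomes (roughly)
//   %event = llvm.call @mgpuEventCreate() : () -> !llvm.ptr<i8>
//   llvm.call @mgpuEventRecord(%event, %stream)
//   llvm.call @mgpuStreamDestroy(%stream)
//   async.yield %event : !llvm.ptr<i8>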
// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(value.getType().isa<LLVM::LLVMPointerType>());
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return defOp.getCallee()->equals(functionName);
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the
// host with the stream/event operands. The operands are destroyed, i.e. it is
// assumed that they are not used afterwards or elsewhere. Otherwise we will
// get a runtime error. Eventually, we should guarantee this property.
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : adaptor.getOperands()) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands. The operands
// are destroyed, i.e. it is assumed that they are not used afterwards or
// elsewhere. Otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair :
       llvm::zip(waitOp.asyncDependencies(), adaptor.getOperands())) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event =
          eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}
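// Illustrative sketch: `%t1 = gpu.wait async [%t0]`, where %t0 was lowered to
// an event %event, becomes (roughly)
//   %stream = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr<i8>
//   llvm.call @mgpuStreamWaitEvent(%stream, %event)
//   llvm.call @mgpuEventDestroy(%event)
// with %stream replacing the token %t1.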
// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
  auto loc = launchOp.getLoc();
  auto numKernelOperands = launchOp.getNumKernelOperands();
  auto arguments = getTypeConverter()->promoteOperands(
      loc, launchOp.getOperands().take_back(numKernelOperands),
      adaptor.getOperands().take_back(numKernelOperands), builder);
  auto numArguments = arguments.size();
  SmallVector<Type, 4> argumentTypes;
  argumentTypes.reserve(numArguments);
  for (auto argument : arguments)
    argumentTypes.push_back(argument.getType());
  auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
                                                           argumentTypes);
  auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                              builder.getI32IntegerAttr(1));
  auto structPtr = builder.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
  auto arraySize = builder.create<LLVM::ConstantOp>(
      loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
  auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
                                                 arraySize, /*alignment=*/0);
  auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               builder.getI32IntegerAttr(0));
  for (const auto &en : llvm::enumerate(arguments)) {
    auto index = builder.create<LLVM::ConstantOp>(
        loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
    auto fieldPtr = builder.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
        ArrayRef<Value>{zero, index.getResult()});
    builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
    auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
                                                  arrayPtr, index.getResult());
    auto casted =
        builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
    builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
  }
  return arrayPtr;
}

// Generates an LLVM IR dialect global that contains the name of the given
// kernel function as a C string, and returns a pointer to its beginning.
// The code is essentially:
//
// llvm.global constant @kernel_name("function_name\00")
// func(...) {
//   %0 = llvm.addressof @kernel_name
//   %1 = llvm.constant (0 : index)
//   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
// }
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
    StringRef moduleName, StringRef name, Location loc,
    OpBuilder &builder) const {
  // Make sure the trailing zero is included in the constant.
  std::vector<char> kernelName(name.begin(), name.end());
  kernelName.push_back('\0');

  std::string globalName =
      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
  return LLVM::createGlobalString(
      loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
      LLVM::Linkage::Internal);
}
// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = call %moduleLoad(%0)
// %2 = <see generateKernelNameConstant>
// %3 = call %moduleGetFunction(%1, %2)
// %4 = call %streamCreate()
// %5 = <see generateParamsArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
// call %streamDestroy(%4)
// call %moduleUnload(%1)
//
// If the op is async, the stream corresponds to the (single) async dependency
// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
    return failure();

  if (launchOp.asyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is
  // no use of the stream after this op.
  if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launchOp, launchOp.getKernelModuleName());
  assert(kernelModule && "expected a kernel module");

  auto binaryAttr =
      kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
  if (!binaryAttr) {
    kernelModule.emitOpError()
        << "missing " << gpuBinaryAnnotation << " attribute";
    return failure();
  }

  SmallString<128> nameBuffer(kernelModule.getName());
  nameBuffer.append(kGpuBinaryStorageSuffix);
  Value data =
      LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
                               binaryAttr.getValue(), LLVM::Linkage::Internal);

  auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
  // Get the function from the module. The name corresponds to the name of
  // the kernel function.
  auto kernelName = generateKernelNameConstant(
      launchOp.getKernelModuleName().getValue(),
      launchOp.getKernelName().getValue(), loc, rewriter);
  auto function = moduleGetFunctionCallBuilder.create(
      loc, rewriter, {module.getResult(0), kernelName});
  auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                                rewriter.getI32IntegerAttr(0));
  Value stream =
      adaptor.asyncDependencies().empty()
          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
          : adaptor.asyncDependencies().front();
  // Create array of pointers to kernel arguments.
  auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
  auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
  Value dynamicSharedMemorySize = launchOp.dynamicSharedMemorySize()
                                      ? launchOp.dynamicSharedMemorySize()
                                      : zero;
  launchKernelCallBuilder.create(
      loc, rewriter,
      {function.getResult(0), adaptor.gridSizeX(), adaptor.gridSizeY(),
       adaptor.gridSizeZ(), adaptor.blockSizeX(), adaptor.blockSizeY(),
       adaptor.blockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
       /*extra=*/nullpointer});

  if (launchOp.asyncToken()) {
    // Async launch: make dependent ops use the same stream.
    rewriter.replaceOp(launchOp, {stream});
  } else {
    // Synchronize with host and destroy stream. This must be the stream
    // created above (with no other uses) because we check that the
    // synchronous version does not have any async dependencies.
    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
    streamDestroyCallBuilder.create(loc, rewriter, stream);
    rewriter.eraseOp(launchOp);
  }
  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));

  return success();
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memcpyOp.src().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();

  MemRefDescriptor srcDesc(adaptor.src());
  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
                                              ArrayRef<Value>{numElements});
  auto sizeBytes =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);

  auto src = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType,
      MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}
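// Note on the size computation above: it uses the standard LLVM "sizeof via
// GEP" idiom. A getelementptr of `numElements` from a null pointer of the
// element pointer type, followed by ptrtoint, yields
// numElements * sizeof(element) without querying the data layout in the
// pattern itself.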
LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemsetOp memsetOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memsetOp.dst().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
    return failure();

  auto loc = memsetOp.getLoc();

  Type valueType = adaptor.value().getType();
  if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
    return rewriter.notifyMatchFailure(memsetOp,
                                       "value must be a 32 bit scalar");
  }

  MemRefDescriptor dstDesc(adaptor.dst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.value());
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}

LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  setDefaultDeviceCallBuilder.create(loc, rewriter, {adaptor.devIndex()});
  rewriter.replaceOp(op, {});
  return success();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createGpuToLLVMConversionPass() {
  return std::make_unique<GpuToLLVMConversionPass>();
}

void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns,
                                               StringRef gpuBinaryAnnotation) {
  converter.addConversion(
      [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
        return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
      });
  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertMemsetOpToGpuRuntimeCallPattern,
               ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
                                                           gpuBinaryAnnotation);
  patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
}