//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert gpu.launch_func op into a sequence of
// GPU runtime calls. As most GPU runtimes do not have a stable published ABI,
// this pass uses a slim runtime layer that builds on top of the public API
// from GPU runtime headers.
//
//===----------------------------------------------------------------------===//
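//
// The mgpu* functions referenced below form that runtime layer: thin
// extern "C" wrappers over the vendor API, provided by a support library. As
// an illustration (a sketch of the CUDA flavor, not part of this file),
// mgpuModuleLoad is essentially:
//
//   extern "C" CUmodule mgpuModuleLoad(void *data) {
//     CUmodule module = nullptr;
//     // Error checking via the wrapper library's reporting macro.
//     CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
//     return module;
//   }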

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "../PassDetail.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";

namespace {

class GpuToLLVMConversionPass
    : public GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  GpuToLLVMConversionPass() = default;

  GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
      : GpuToLLVMConversionPassBase(other) {}

  // Run the dialect converter on the module.
  void runOnOperation() override;

private:
  Option<std::string> gpuBinaryAnnotation{
      *this, "gpu-binary-annotation",
      llvm::cl::desc("Annotation attribute string for GPU binary"),
      llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
};

struct FunctionCallBuilder {
  FunctionCallBuilder(StringRef functionName, Type returnType,
                      ArrayRef<Type> argumentTypes)
      : functionName(functionName),
        functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
  LLVM::CallOp create(Location loc, OpBuilder &builder,
                      ArrayRef<Value> arguments) const;

  StringRef functionName;
  LLVM::LLVMFunctionType functionType;
};
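
// FunctionCallBuilder::create (defined below) declares the runtime function
// on first use, if it is not already in the module, and emits a call to it;
// the patterns below use the builders like:
//
//   auto module = moduleLoadCallBuilder.create(loc, rewriter, data);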

template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  Type llvmPointerType =
      LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
  Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

  FunctionCallBuilder moduleLoadCallBuilder = {
      "mgpuModuleLoad",
      llvmPointerType /* void *module */,
      {llvmPointerType /* void *cubin */}};
  FunctionCallBuilder moduleUnloadCallBuilder = {
      "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
  FunctionCallBuilder moduleGetFunctionCallBuilder = {
      "mgpuModuleGetFunction",
      llvmPointerType /* void *function */,
      {
          llvmPointerType, /* void *module */
          llvmPointerType  /* char *name */
      }};
  FunctionCallBuilder launchKernelCallBuilder = {
      "mgpuLaunchKernel",
      llvmVoidType,
      {
          llvmPointerType,        /* void* f */
          llvmIntPtrType,         /* intptr_t gridXDim */
          llvmIntPtrType,         /* intptr_t gridYDim */
          llvmIntPtrType,         /* intptr_t gridZDim */
          llvmIntPtrType,         /* intptr_t blockXDim */
          llvmIntPtrType,         /* intptr_t blockYDim */
          llvmIntPtrType,         /* intptr_t blockZDim */
          llvmInt32Type,          /* unsigned int sharedMemBytes */
          llvmPointerType,        /* void *hstream */
          llvmPointerPointerType, /* void **kernelParams */
          llvmPointerPointerType  /* void **extra */
      }};
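  // The kernelParams/extra pair mirrors the cuLaunchKernel /
  // hipModuleLaunchKernel interface: kernelParams is an array of type-erased
  // pointers to the individual kernel arguments, built by generateParamsArray
  // below.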
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
};

/// A rewrite pattern to convert gpu.host_register operations into a GPU
/// runtime call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.launch_func operations into a sequence of
/// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
///
/// In essence, a gpu.launch_func operation gets compiled into the following
/// sequence of runtime calls:
///
/// * moduleLoad -- loads the module given the cubin / hsaco data
/// * moduleGetFunction -- gets a handle to the actual kernel function
/// * getStreamHelper -- initializes a new compute stream on GPU
/// * launchKernel -- launches the kernel on a stream
/// * streamSynchronize -- waits for operations on the stream to finish
///
/// Intermediate data structures are allocated on the stack.
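///
/// As an illustration, the op being lowered looks roughly like:
///
///   gpu.launch_func @kernel_module::@kernel
///       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
///       args(%arg0 : f32, %arg1 : memref<?xf32>)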
class ConvertLaunchFuncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
                                             StringRef gpuBinaryAnnotation)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        gpuBinaryAnnotation(gpuBinaryAnnotation) {}

private:
  Value generateParamsArray(gpu::LaunchFuncOp launchOp,
                            ArrayRef<Value> operands, OpBuilder &builder) const;
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;

  llvm::SmallString<32> gpuBinaryAnnotation;
};

class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
  using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
                                PatternRewriter &rewriter) const override {
    // GPU kernel modules are no longer necessary since we have a global
    // constant with the CUBIN or HSACO data.
    rewriter.eraseOp(op);
    return success();
  }
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
  LLVMTypeConverter converter(&getContext());
  RewritePatternSet patterns(&getContext());
  LLVMConversionTarget target(getContext());

  target.addIllegalDialect<gpu::GPUDialect>();

  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);

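  // Lower !gpu.async.token to a bare pointer (void *): at runtime, a token is
  // realized as the stream or event handle produced by the mgpu* calls.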
  converter.addConversion(
      [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
        return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
      });
  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
                                                           gpuBinaryAnnotation);
  patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(
      loc, const_cast<LLVM::LLVMFunctionType &>(functionType).getReturnType(),
      builder.getSymbolRefAttr(function), arguments);
}

// Returns success if all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, operands, rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.value().getType();
  auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(loc, op->getOperands(),
                                                       operands, rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}
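
// Converts a gpu.alloc operation, e.g. (sketch)
//
//   %memref, %token = gpu.alloc async [%dep] (%size) : memref<?xf32>
//
// into an mgpuMemAlloc call on the stream implementing %dep, and wraps the
// returned pointer in a memref descriptor.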
LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, operands, rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  auto loc = allocOp.getLoc();
  auto adaptor = gpu::AllocOpAdaptor(operands, allocOp->getAttrDictionary());

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.dynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  Type elementPtrType = this->getElementPtrType(memRefType);
  auto stream = adaptor.asyncDependencies().front();
  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
  allocatedPtr =
      rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});

  return success();
}

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, operands, rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  auto adaptor =
      gpu::DeallocOpAdaptor(operands, deallocOp->getAttrDictionary());
  Value pointer =
      MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
  auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
  Value stream = adaptor.asyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {casted, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return value.getType().isa<gpu::AsyncTokenType>();
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token values are lowered to streams within the async.execute
// region, but are passed between regions as events. For each !gpu.async.token
// operand, we create an event and record it on the stream.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(operands.begin(), operands.end());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = operands[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.updateRootInPlace(yieldOp,
                             [&] { yieldOp->setOperands(newOperands); });
  return success();
}

// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(value.getType().isa<LLVM::LLVMPointerType>());
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return defOp.callee()->equals(functionName);
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
// with the stream/event operands. The operands are destroyed, i.e. it is
// assumed that they are not used afterwards or elsewhere. Otherwise we will
// get a runtime error. Eventually, we should guarantee this property.
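//
// For example (sketch), where %t was lowered to a stream:
//
//   %t = gpu.wait async  -->  %stream = llvm.call @mgpuStreamCreate()
//   gpu.wait [%t]        -->  llvm.call @mgpuStreamSynchronize(%stream)
//                             llvm.call @mgpuStreamDestroy(%stream)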
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : operands) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands. The operands
// are destroyed, i.e. it is assumed that they are not used afterwards or
// elsewhere. Otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
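//
// For example (sketch), where %e was lowered to an event recorded on another
// stream:
//
//   %t = gpu.wait async [%e]  -->  %stream = llvm.call @mgpuStreamCreate()
//                                  llvm.call @mgpuStreamWaitEvent(%stream, %e)
//                                  llvm.call @mgpuEventDestroy(%e)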
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair : llvm::zip(waitOp.asyncDependencies(), operands)) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event =
          eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
    OpBuilder &builder) const {
  auto loc = launchOp.getLoc();
  auto numKernelOperands = launchOp.getNumKernelOperands();
  auto arguments = getTypeConverter()->promoteOperands(
      loc, launchOp.getOperands().take_back(numKernelOperands),
      operands.take_back(numKernelOperands), builder);
  auto numArguments = arguments.size();
  SmallVector<Type, 4> argumentTypes;
  argumentTypes.reserve(numArguments);
  for (auto argument : arguments)
    argumentTypes.push_back(argument.getType());
  auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
                                                           argumentTypes);
  auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                              builder.getI32IntegerAttr(1));
  auto structPtr = builder.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
  auto arraySize = builder.create<LLVM::ConstantOp>(
      loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
  auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
                                                 arraySize, /*alignment=*/0);
  auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               builder.getI32IntegerAttr(0));
  for (auto en : llvm::enumerate(arguments)) {
    auto index = builder.create<LLVM::ConstantOp>(
        loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
    auto fieldPtr = builder.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
        ArrayRef<Value>{zero, index.getResult()});
    builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
    auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
                                                  arrayPtr, index.getResult());
    auto casted =
        builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
    builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
  }
  return arrayPtr;
}

// Generates an LLVM IR dialect global that contains the name of the given
// kernel function as a C string, and returns a pointer to its beginning.
// The code is essentially:
//
// llvm.global constant @kernel_name("function_name\00")
// func(...) {
//   %0 = llvm.addressof @kernel_name
//   %1 = llvm.constant (0 : index)
//   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
// }
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
    StringRef moduleName, StringRef name, Location loc,
    OpBuilder &builder) const {
  // Make sure the trailing zero is included in the constant.
  std::vector<char> kernelName(name.begin(), name.end());
  kernelName.push_back('\0');

  std::string globalName =
      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
  return LLVM::createGlobalString(
      loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
      LLVM::Linkage::Internal);
}

// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = call %moduleLoad(%0)
// %2 = <see generateKernelNameConstant>
// %3 = call %moduleGetFunction(%1, %2)
// %4 = call %streamCreate()
// %5 = <see generateParamsArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
// call %streamDestroy(%4)
// call %moduleUnload(%1)
//
// If the op is async, the stream corresponds to the (single) async dependency
// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, operands, rewriter)))
    return failure();

  if (launchOp.asyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launchOp, launchOp.getKernelModuleName());
  assert(kernelModule && "expected a kernel module");

  auto binaryAttr =
      kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
  if (!binaryAttr) {
    kernelModule.emitOpError()
        << "missing " << gpuBinaryAnnotation << " attribute";
    return failure();
  }

  SmallString<128> nameBuffer(kernelModule.getName());
  nameBuffer.append(kGpuBinaryStorageSuffix);
  Value data =
      LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
                               binaryAttr.getValue(), LLVM::Linkage::Internal);

  auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
  // Get the function from the module. The name corresponds to the name of
  // the kernel function.
  auto kernelName = generateKernelNameConstant(
      launchOp.getKernelModuleName(), launchOp.getKernelName(), loc, rewriter);
  auto function = moduleGetFunctionCallBuilder.create(
      loc, rewriter, {module.getResult(0), kernelName});
  auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                                rewriter.getI32IntegerAttr(0));
  auto adaptor =
      gpu::LaunchFuncOpAdaptor(operands, launchOp->getAttrDictionary());
  Value stream =
      adaptor.asyncDependencies().empty()
          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
          : adaptor.asyncDependencies().front();
  // Create array of pointers to kernel arguments.
  auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
  auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
  launchKernelCallBuilder.create(loc, rewriter,
                                 {function.getResult(0), launchOp.gridSizeX(),
                                  launchOp.gridSizeY(), launchOp.gridSizeZ(),
                                  launchOp.blockSizeX(), launchOp.blockSizeY(),
                                  launchOp.blockSizeZ(),
                                  /*sharedMemBytes=*/zero, stream, kernelParams,
                                  /*extra=*/nullpointer});

  if (launchOp.asyncToken()) {
    // Async launch: make dependent ops use the same stream.
    rewriter.replaceOp(launchOp, {stream});
  } else {
    // Synchronize with host and destroy stream. This must be the stream
    // created above (with no other uses) because we check that the synchronous
    // version does not have any async dependencies.
    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
    streamDestroyCallBuilder.create(loc, rewriter, stream);
    rewriter.eraseOp(launchOp);
  }
  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));

  return success();
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memcpyOp.src().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memcpyOp, operands, rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();
  auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary());

  MemRefDescriptor srcDesc(adaptor.src());

  Value numElements =
      memRefType.hasStaticShape()
          ? createIndexConstant(rewriter, loc, memRefType.getNumElements())
          // For identity layouts (verified above), the number of elements is
          // stride[0] * size[0].
          : rewriter.create<LLVM::MulOp>(loc, srcDesc.stride(rewriter, loc, 0),
                                         srcDesc.size(rewriter, loc, 0));

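  // Compute the size in bytes as numElements * sizeof(element) with the
  // null-pointer GEP idiom: the address of element numElements in an array
  // based at null equals the byte size of the data.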
  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(
      loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements});
  auto sizeBytes =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);

  auto src = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType,
      MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createGpuToLLVMConversionPass() {
  return std::make_unique<GpuToLLVMConversionPass>();
}
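
// Typical usage (a sketch; assumes the standard pass registration under the
// "gpu-to-llvm" name):
//
//   mlir-opt --gpu-to-llvm='gpu-binary-annotation=nvvm.cubin' input.mlir
//
// The emitted mgpu* calls are resolved at runtime by linking against one of
// MLIR's GPU runtime wrapper libraries (CUDA or ROCm).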