//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the gpu.launch_func op into a
// sequence of GPU runtime calls. As most GPU runtimes do not have a stable
// published ABI, this pass uses a slim runtime layer that builds on top of
// the public API from GPU runtime headers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "../PassDetail.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";

namespace {

class GpuToLLVMConversionPass
    : public GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  GpuToLLVMConversionPass() = default;

  GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
      : GpuToLLVMConversionPassBase(other) {}

  // Run the dialect converter on the module.
  void runOnOperation() override;

private:
  Option<std::string> gpuBinaryAnnotation{
      *this, "gpu-binary-annotation",
      llvm::cl::desc("Annotation attribute string for GPU binary"),
      llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
};

struct FunctionCallBuilder {
  FunctionCallBuilder(StringRef functionName, Type returnType,
                      ArrayRef<Type> argumentTypes)
      : functionName(functionName),
        functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
  LLVM::CallOp create(Location loc, OpBuilder &builder,
                      ArrayRef<Value> arguments) const;

  StringRef functionName;
  LLVM::LLVMFunctionType functionType;
};

template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  Type llvmPointerType =
      LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
  Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

  FunctionCallBuilder moduleLoadCallBuilder = {
      "mgpuModuleLoad",
      llvmPointerType /* void *module */,
      {llvmPointerType /* void *cubin */}};
  FunctionCallBuilder moduleUnloadCallBuilder = {
      "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
  FunctionCallBuilder moduleGetFunctionCallBuilder = {
      "mgpuModuleGetFunction",
      llvmPointerType /* void *function */,
      {
          llvmPointerType, /* void *module */
          llvmPointerType  /* char *name */
      }};
  FunctionCallBuilder launchKernelCallBuilder = {
      "mgpuLaunchKernel",
      llvmVoidType,
      {
          llvmPointerType,        /* void* f */
          llvmIntPtrType,         /* intptr_t gridXDim */
          llvmIntPtrType,         /* intptr_t gridYDim */
          llvmIntPtrType,         /* intptr_t gridZDim */
          llvmIntPtrType,         /* intptr_t blockXDim */
          llvmIntPtrType,         /* intptr_t blockYDim */
          llvmIntPtrType,         /* intptr_t blockZDim */
          llvmInt32Type,          /* unsigned int sharedMemBytes */
          llvmPointerType,        /* void *hstream */
          llvmPointerPointerType, /* void **kernelParams */
          llvmPointerPointerType  /* void **extra */
      }};
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
};
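
// For illustration only: given the builders above, a call such as
// `moduleLoadCallBuilder.create(loc, rewriter, {data})` emits an
// `llvm.call @mgpuModuleLoad(%data)` operation, inserting the matching
// llvm.func declaration at module scope first if it does not yet exist (see
// FunctionCallBuilder::create below). Schematically:
//
//   llvm.func @mgpuModuleLoad(!llvm.ptr<i8>) -> !llvm.ptr<i8>
//   ...
//   %module = llvm.call @mgpuModuleLoad(%data)
//       : (!llvm.ptr<i8>) -> !llvm.ptr<i8>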

/// A rewrite pattern to convert gpu.host_register operations into a GPU
/// runtime call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.launch_func operations into a sequence of
/// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
///
/// In essence, a gpu.launch_func operation gets compiled into the following
/// sequence of runtime calls:
///
/// * moduleLoad -- loads the module given the cubin / hsaco data
/// * moduleGetFunction -- gets a handle to the actual kernel function
/// * streamCreate -- initializes a new compute stream on the GPU
/// * launchKernel -- launches the kernel on a stream
/// * streamSynchronize -- waits for operations on the stream to finish
///
/// Intermediate data structures are allocated on the stack.
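///
/// For illustration only (schematic MLIR, names and types abbreviated):
///
///   gpu.launch_func @kernel_module::@kernel
///       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
///       args(%arg0 : f32, %arg1 : memref<?xf32>)
///
/// is rewritten into the moduleLoad / moduleGetFunction / launchKernel /
/// streamSynchronize sequence sketched above.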
class ConvertLaunchFuncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
                                             StringRef gpuBinaryAnnotation)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        gpuBinaryAnnotation(gpuBinaryAnnotation) {}

private:
  Value generateParamsArray(gpu::LaunchFuncOp launchOp,
                            ArrayRef<Value> operands, OpBuilder &builder) const;
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;

  llvm::SmallString<32> gpuBinaryAnnotation;
};

class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
  using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
                                PatternRewriter &rewriter) const override {
    // GPU kernel modules are no longer necessary since we have a global
    // constant with the CUBIN or HSACO data.
    rewriter.eraseOp(op);
    return success();
  }
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
  LLVMTypeConverter converter(&getContext());
  RewritePatternSet patterns(&getContext());
  LLVMConversionTarget target(getContext());

  target.addIllegalDialect<gpu::GPUDialect>();

  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);

  // Lower !gpu.async.token to an opaque pointer (the runtime stream/event).
  converter.addConversion(
      [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
        return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
      });
  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
                                                           gpuBinaryAnnotation);
  patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  // Reuse an existing declaration of the runtime function if present;
  // otherwise insert one at the end of the module.
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(
      loc, const_cast<LLVM::LLVMFunctionType &>(functionType).getReturnType(),
      builder.getSymbolRefAttr(function), arguments);
}

// Returns whether all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, operands, rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.value().getType();
  auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(loc, op->getOperands(),
                                                       operands, rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}
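
// For illustration only (schematic MLIR): the gpu.alloc pattern below rewrites
//
//   %memref, %t1 = gpu.alloc async [%t0] (%size) : memref<?xf32>
//
// into an mgpuMemAlloc call on the stream that %t0 was lowered to, and then
// wraps the returned pointer in a memref descriptor.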
LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, operands, rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  auto loc = allocOp.getLoc();
  auto adaptor = gpu::AllocOpAdaptor(operands, allocOp->getAttrDictionary());

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.dynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  Type elementPtrType = this->getElementPtrType(memRefType);
  auto stream = adaptor.asyncDependencies().front();
  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
  allocatedPtr =
      rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});

  return success();
}

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, operands, rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  auto adaptor =
      gpu::DeallocOpAdaptor(operands, deallocOp->getAttrDictionary());
  Value pointer =
      MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
  auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
  Value stream = adaptor.asyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {casted, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return value.getType().isa<gpu::AsyncTokenType>();
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token values are lowered to streams within the async.execute
// region, but are passed between regions as events. For each !gpu.async.token
// operand, we create an event and record it on the stream.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(operands.begin(), operands.end());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = operands[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.updateRootInPlace(yieldOp,
                             [&] { yieldOp->setOperands(newOperands); });
  return success();
}

// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(value.getType().isa<LLVM::LLVMPointerType>());
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return defOp.callee()->equals(functionName);
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
// with the stream/event operands, then destroys them. That is, it assumes the
// operands are not used afterwards or elsewhere; otherwise we will get a
// runtime error. Eventually, we should guarantee this property.
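//
// For illustration only (schematic): `gpu.wait [%t]`, where %t was lowered to
// a stream, becomes roughly
//
//   llvm.call @mgpuStreamSynchronize(%t) : (!llvm.ptr<i8>) -> ()
//   llvm.call @mgpuStreamDestroy(%t) : (!llvm.ptr<i8>) -> ()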
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : operands) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands, then destroys
// the operands. That is, it assumes they are not used afterwards or elsewhere;
// otherwise we will get a runtime error. Eventually, we should guarantee this
// property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair : llvm::zip(waitOp.asyncDependencies(), operands)) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event =
          eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}
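
// For illustration only (schematic): the `gpu.wait async` lowering above turns
// `%t = gpu.wait async [%t0]`, where %t0 was lowered to a stream, into roughly
//
//   %e = llvm.call @mgpuEventCreate() : () -> !llvm.ptr<i8>
//   llvm.call @mgpuEventRecord(%e, %t0)
//       : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
//   %t = llvm.call @mgpuStreamCreate() : () -> !llvm.ptr<i8>
//   llvm.call @mgpuStreamWaitEvent(%t, %e)
//       : (!llvm.ptr<i8>, !llvm.ptr<i8>) -> ()
//   llvm.call @mgpuEventDestroy(%e) : (!llvm.ptr<i8>) -> ()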

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
    OpBuilder &builder) const {
  auto loc = launchOp.getLoc();
  auto numKernelOperands = launchOp.getNumKernelOperands();
  auto arguments = getTypeConverter()->promoteOperands(
      loc, launchOp.getOperands().take_back(numKernelOperands),
      operands.take_back(numKernelOperands), builder);
  auto numArguments = arguments.size();
  SmallVector<Type, 4> argumentTypes;
  argumentTypes.reserve(numArguments);
  for (auto argument : arguments)
    argumentTypes.push_back(argument.getType());
  auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
                                                           argumentTypes);
  auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                              builder.getI32IntegerAttr(1));
  auto structPtr = builder.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
  auto arraySize = builder.create<LLVM::ConstantOp>(
      loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
  auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
                                                 arraySize, /*alignment=*/0);
  auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               builder.getI32IntegerAttr(0));
  for (auto en : llvm::enumerate(arguments)) {
    auto index = builder.create<LLVM::ConstantOp>(
        loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
    auto fieldPtr = builder.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
        ArrayRef<Value>{zero, index.getResult()});
    builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
    auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
                                                  arrayPtr, index.getResult());
    auto casted =
        builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
    builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
  }
  return arrayPtr;
}

// Generates an LLVM IR dialect global that contains the name of the given
// kernel function as a C string, and returns a pointer to its beginning.
// The code is essentially:
//
// llvm.global constant @kernel_name("function_name\00")
// func(...) {
//   %0 = llvm.addressof @kernel_name
//   %1 = llvm.constant (0 : index)
//   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
// }
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
    StringRef moduleName, StringRef name, Location loc,
    OpBuilder &builder) const {
  // Make sure the trailing zero is included in the constant.
  std::vector<char> kernelName(name.begin(), name.end());
  kernelName.push_back('\0');

  std::string globalName =
      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
  return LLVM::createGlobalString(
      loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
      LLVM::Linkage::Internal);
}

// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = call %moduleLoad(%0)
// %2 = <see generateKernelNameConstant>
// %3 = call %moduleGetFunction(%1, %2)
// %4 = call %streamCreate()
// %5 = <see generateParamsArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
// call %streamDestroy(%4)
// call %moduleUnload(%1)
//
// If the op is async, the stream corresponds to the (single) async dependency
// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, operands, rewriter)))
    return failure();

  if (launchOp.asyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launchOp, launchOp.getKernelModuleName());
  assert(kernelModule && "expected a kernel module");

  auto binaryAttr =
      kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
  if (!binaryAttr) {
    kernelModule.emitOpError()
        << "missing " << gpuBinaryAnnotation << " attribute";
    return failure();
  }

  SmallString<128> nameBuffer(kernelModule.getName());
  nameBuffer.append(kGpuBinaryStorageSuffix);
  Value data =
      LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
                               binaryAttr.getValue(), LLVM::Linkage::Internal);

  auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
  // Get the function from the module. The name corresponds to the name of
  // the kernel function.
  auto kernelName = generateKernelNameConstant(
      launchOp.getKernelModuleName(), launchOp.getKernelName(), loc, rewriter);
  auto function = moduleGetFunctionCallBuilder.create(
      loc, rewriter, {module.getResult(0), kernelName});
  auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                                rewriter.getI32IntegerAttr(0));
  auto adaptor =
      gpu::LaunchFuncOpAdaptor(operands, launchOp->getAttrDictionary());
  Value stream =
      adaptor.asyncDependencies().empty()
          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
          : adaptor.asyncDependencies().front();
  // Create array of pointers to kernel arguments.
  auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
  auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
  launchKernelCallBuilder.create(loc, rewriter,
                                 {function.getResult(0), launchOp.gridSizeX(),
                                  launchOp.gridSizeY(), launchOp.gridSizeZ(),
                                  launchOp.blockSizeX(), launchOp.blockSizeY(),
                                  launchOp.blockSizeZ(),
                                  /*sharedMemBytes=*/zero, stream, kernelParams,
                                  /*extra=*/nullpointer});

  if (launchOp.asyncToken()) {
    // Async launch: make dependent ops use the same stream.
    rewriter.replaceOp(launchOp, {stream});
  } else {
    // Synchronize with the host and destroy the stream. This must be the
    // stream created above (with no other uses) because we check that the
    // synchronous version of the op does not have async dependencies.
    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
    streamDestroyCallBuilder.create(loc, rewriter, stream);
    rewriter.eraseOp(launchOp);
  }
  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));

  return success();
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memcpyOp.src().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memcpyOp, operands, rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();
  auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary());

  MemRefDescriptor srcDesc(adaptor.src());

  Value numElements =
      memRefType.hasStaticShape()
          ? createIndexConstant(rewriter, loc, memRefType.getNumElements())
          // For identity layouts (verified above), the number of elements is
          // stride[0] * size[0].
          : rewriter.create<LLVM::MulOp>(loc, srcDesc.stride(rewriter, loc, 0),
                                         srcDesc.size(rewriter, loc, 0));

  // Compute the size in bytes via the getelementptr idiom: the address of
  // element `numElements` in an array based at null equals the buffer size.
  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(
      loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements});
  auto sizeBytes =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);

  auto src = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType,
      MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createGpuToLLVMConversionPass() {
  return std::make_unique<GpuToLLVMConversionPass>();
}
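
// Illustrative usage only (not part of this file's logic): the pass created
// above is typically scheduled on the top-level module, e.g.
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mlir::createGpuToLLVMConversionPass());
//   (void)pm.run(module); // assumes kernels were already compiled to
//                         // cubin/hsaco and attached via gpuBinaryAnnotation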