//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert gpu.launch_func ops into a sequence
// of GPU runtime calls. As most GPU runtimes do not have a stable published
// ABI, this pass uses a slim runtime layer that builds on top of the public
// API from GPU runtime headers.
//
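// As an illustrative sketch (the exact IR depends on the type converter and
// the operand types), a launch such as
//
//   gpu.launch_func @kernels::@kernel
//       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
//       args(%arg0 : f32)
//
// is lowered to calls to mgpuModuleLoad, mgpuModuleGetFunction,
// mgpuStreamCreate, mgpuLaunchKernel, and mgpuStreamSynchronize.
//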
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "../PassDetail.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"
using namespace mlir;

static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";

namespace {

class GpuToLLVMConversionPass
    : public GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  GpuToLLVMConversionPass() = default;

  GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
      : GpuToLLVMConversionPassBase(other) {}

  // Run the dialect converter on the module.
  void runOnOperation() override;

private:
  Option<std::string> gpuBinaryAnnotation{
      *this, "gpu-binary-annotation",
      llvm::cl::desc("Annotation attribute string for GPU binary"),
      llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
};

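/// Helper that builds LLVM::CallOps to a runtime function with the given name
/// and signature, declaring the function at module scope on first use. A
/// minimal usage sketch, using one of the builders defined further below:
///
///   // Creates a stream via a call to mgpuStreamCreate.
///   Value stream =
///       streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);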
struct FunctionCallBuilder {
  FunctionCallBuilder(StringRef functionName, Type returnType,
                      ArrayRef<Type> argumentTypes)
      : functionName(functionName),
        functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
  LLVM::CallOp create(Location loc, OpBuilder &builder,
                      ArrayRef<Value> arguments) const;

  StringRef functionName;
  LLVM::LLVMFunctionType functionType;
};

template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  Type llvmPointerType =
      LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
  Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

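  // Builders for calls into the slim mgpu* runtime layer described in the
  // file header comment. The wrappers are expected to be provided by the GPU
  // runtime wrapper library that is linked into the final executable.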
  FunctionCallBuilder moduleLoadCallBuilder = {
      "mgpuModuleLoad",
      llvmPointerType /* void *module */,
      {llvmPointerType /* void *cubin */}};
  FunctionCallBuilder moduleUnloadCallBuilder = {
      "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
  FunctionCallBuilder moduleGetFunctionCallBuilder = {
      "mgpuModuleGetFunction",
      llvmPointerType /* void *function */,
      {
          llvmPointerType, /* void *module */
          llvmPointerType  /* char *name   */
      }};
  FunctionCallBuilder launchKernelCallBuilder = {
      "mgpuLaunchKernel",
      llvmVoidType,
      {
          llvmPointerType,        /* void* f */
          llvmIntPtrType,         /* intptr_t gridXDim */
          llvmIntPtrType,         /* intptr_t gridYDim */
          llvmIntPtrType,         /* intptr_t gridZDim */
          llvmIntPtrType,         /* intptr_t blockXDim */
          llvmIntPtrType,         /* intptr_t blockYDim */
          llvmIntPtrType,         /* intptr_t blockZDim */
          llvmInt32Type,          /* unsigned int sharedMemBytes */
          llvmPointerType,        /* void *hstream */
          llvmPointerPointerType, /* void **kernelParams */
          llvmPointerPointerType  /* void **extra */
      }};
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
};

/// A rewrite pattern to convert gpu.host_register operations into a GPU
/// runtime call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

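/// A rewrite pattern to convert async.yield operations that yield
/// !gpu.async.token values into GPU runtime calls; see the comment on
/// matchAndRewrite below for details.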
class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.launch_func operations into a sequence of
/// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
///
/// In essence, a gpu.launch_func operation gets compiled into the following
/// sequence of runtime calls:
///
/// * moduleLoad        -- loads the module given the cubin / hsaco data
/// * moduleGetFunction -- gets a handle to the actual kernel function
/// * streamCreate      -- creates a new compute stream on the GPU
/// * launchKernel      -- launches the kernel on a stream
/// * streamSynchronize -- waits for operations on the stream to finish
///
/// Intermediate data structures are allocated on the stack.
class ConvertLaunchFuncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
                                             StringRef gpuBinaryAnnotation)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        gpuBinaryAnnotation(gpuBinaryAnnotation) {}

private:
  Value generateParamsArray(gpu::LaunchFuncOp launchOp,
                            ArrayRef<Value> operands, OpBuilder &builder) const;
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;

  llvm::SmallString<32> gpuBinaryAnnotation;
};

class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
  using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
                                PatternRewriter &rewriter) const override {
    // GPU kernel modules are no longer necessary since we have a global
    // constant with the CUBIN or HSACO data.
    rewriter.eraseOp(op);
    return success();
  }
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
  LLVMTypeConverter converter(&getContext());
  RewritePatternSet patterns(&getContext());
  LLVMConversionTarget target(getContext());

  target.addIllegalDialect<gpu::GPUDialect>();

  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);

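  // The !gpu.async.token type is lowered to an opaque pointer: a stream
  // within an async.execute region, and an event when passed between regions.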
  converter.addConversion(
      [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
        return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
      });
  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
                                                           gpuBinaryAnnotation);
  patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

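// Emits a call to the runtime function, inserting the function declaration at
// the end of the module if it is not present yet.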
LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(
      loc, const_cast<LLVM::LLVMFunctionType &>(functionType).getReturnType(),
      builder.getSymbolRefAttr(function), arguments);
}

// Returns success if all operands are of LLVM-compatible type, and a match
// failure otherwise.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, operands, rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.value().getType();
  auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(loc, op->getOperands(),
                                                       operands, rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, operands, rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  auto loc = allocOp.getLoc();
  auto adaptor = gpu::AllocOpAdaptor(operands, allocOp->getAttrDictionary());

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.dynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  Type elementPtrType = this->getElementPtrType(memRefType);
  auto stream = adaptor.asyncDependencies().front();
  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
  allocatedPtr =
      rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});

  return success();
}

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, operands, rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  auto adaptor =
      gpu::DeallocOpAdaptor(operands, deallocOp->getAttrDictionary());
  Value pointer =
      MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
  auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
  Value stream = adaptor.asyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {casted, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return value.getType().isa<gpu::AsyncTokenType>();
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token values are lowered to streams within the async.execute
// region, but are passed as events between regions. For each !gpu.async.token
// operand, we create an event and record it on the stream.
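//
// As a sketch (illustrative, not verbatim output), a yielded stream %s in
//
//   async.yield %s : !gpu.async.token
//
// becomes
//
//   %e = call @mgpuEventCreate()
//   call @mgpuEventRecord(%e, %s)
//   call @mgpuStreamDestroy(%s)
//   async.yield %e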
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(operands.begin(), operands.end());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = operands[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.updateRootInPlace(yieldOp,
                             [&] { yieldOp->setOperands(newOperands); });
  return success();
}

// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(value.getType().isa<LLVM::LLVMPointerType>());
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return defOp.callee()->equals(functionName);
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
// with the stream/event operands. The operands are destroyed, i.e. it is
// assumed that they are not used afterwards or elsewhere. Otherwise we will
// get a runtime error. Eventually, we should guarantee this property.
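//
// As a sketch (illustrative), for a stream operand %s this emits
//
//   call @mgpuStreamSynchronize(%s)
//   call @mgpuStreamDestroy(%s)
//
// and for an event operand %e
//
//   call @mgpuEventSynchronize(%e)
//   call @mgpuEventDestroy(%e)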
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : operands) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands. The operands
// are destroyed, i.e. it is assumed that they are not used afterwards or
// elsewhere. Otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
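//
// As a sketch (illustrative), for an event operand %e this emits
//
//   %s = call @mgpuStreamCreate()
//   call @mgpuStreamWaitEvent(%s, %e)
//   call @mgpuEventDestroy(%e)
//
// and the op's !gpu.async.token result is replaced by the new stream %s.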
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair : llvm::zip(waitOp.asyncDependencies(), operands)) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event =
          eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
    OpBuilder &builder) const {
  auto loc = launchOp.getLoc();
  auto numKernelOperands = launchOp.getNumKernelOperands();
  auto arguments = getTypeConverter()->promoteOperands(
      loc, launchOp.getOperands().take_back(numKernelOperands),
      operands.take_back(numKernelOperands), builder);
  auto numArguments = arguments.size();
  SmallVector<Type, 4> argumentTypes;
  argumentTypes.reserve(numArguments);
  for (auto argument : arguments)
    argumentTypes.push_back(argument.getType());
  auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
                                                           argumentTypes);
  auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                              builder.getI32IntegerAttr(1));
  auto structPtr = builder.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
  auto arraySize = builder.create<LLVM::ConstantOp>(
      loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
  auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
                                                 arraySize, /*alignment=*/0);
  auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               builder.getI32IntegerAttr(0));
  for (auto en : llvm::enumerate(arguments)) {
    auto index = builder.create<LLVM::ConstantOp>(
        loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
    auto fieldPtr = builder.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
        ArrayRef<Value>{zero, index.getResult()});
    builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
    auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
                                                  arrayPtr, index.getResult());
    auto casted =
        builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
    builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
  }
  return arrayPtr;
}

// Generates an LLVM IR dialect global that contains the name of the given
// kernel function as a C string, and returns a pointer to its beginning.
// The code is essentially:
//
// llvm.global constant @kernel_name("function_name\00")
// func(...) {
//   %0 = llvm.addressof @kernel_name
//   %1 = llvm.constant (0 : index)
//   %2 = llvm.getelementptr %0[%1, %1] : !llvm.ptr<i8>
// }
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
    StringRef moduleName, StringRef name, Location loc,
    OpBuilder &builder) const {
  // Make sure the trailing zero is included in the constant.
  std::vector<char> kernelName(name.begin(), name.end());
  kernelName.push_back('\0');

  std::string globalName =
      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
  return LLVM::createGlobalString(
      loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
      LLVM::Linkage::Internal);
}

// Emits LLVM IR to launch a kernel function. Expects the kernel module in the
// IR to contain the compiled kernel as a cubin in the 'nvvm.cubin' attribute,
// or a hsaco in the 'rocdl.hsaco' attribute.
//
// %0 = call %binarygetter
// %1 = call %moduleLoad(%0)
// %2 = <see generateKernelNameConstant>
// %3 = call %moduleGetFunction(%1, %2)
// %4 = call %streamCreate()
// %5 = <see generateParamsArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
// call %streamDestroy(%4)
// call %moduleUnload(%1)
//
// If the op is async, the stream corresponds to the (single) async dependency
// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, operands, rewriter)))
    return failure();

  if (launchOp.asyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  // Create an LLVM global with the CUBIN / HSACO data extracted from the
  // kernel module annotation and obtain a pointer to the first byte in it.
  auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launchOp, launchOp.getKernelModuleName());
  assert(kernelModule && "expected a kernel module");

  auto binaryAttr =
      kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
  if (!binaryAttr) {
    kernelModule.emitOpError()
        << "missing " << gpuBinaryAnnotation << " attribute";
    return failure();
  }

  SmallString<128> nameBuffer(kernelModule.getName());
  nameBuffer.append(kGpuBinaryStorageSuffix);
  Value data =
      LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
                               binaryAttr.getValue(), LLVM::Linkage::Internal);

  auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
  // Get the function from the module. The name corresponds to the name of
  // the kernel function.
  auto kernelName = generateKernelNameConstant(
      launchOp.getKernelModuleName(), launchOp.getKernelName(), loc, rewriter);
  auto function = moduleGetFunctionCallBuilder.create(
      loc, rewriter, {module.getResult(0), kernelName});
  auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                                rewriter.getI32IntegerAttr(0));
  auto adaptor =
      gpu::LaunchFuncOpAdaptor(operands, launchOp->getAttrDictionary());
  Value stream =
      adaptor.asyncDependencies().empty()
          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
          : adaptor.asyncDependencies().front();
  // Create array of pointers to kernel arguments.
  auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
  auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
  launchKernelCallBuilder.create(loc, rewriter,
                                 {function.getResult(0), launchOp.gridSizeX(),
                                  launchOp.gridSizeY(), launchOp.gridSizeZ(),
                                  launchOp.blockSizeX(), launchOp.blockSizeY(),
                                  launchOp.blockSizeZ(),
                                  /*sharedMemBytes=*/zero, stream, kernelParams,
                                  /*extra=*/nullpointer});

  if (launchOp.asyncToken()) {
    // Async launch: make dependent ops use the same stream.
    rewriter.replaceOp(launchOp, {stream});
  } else {
    // Synchronize with host and destroy stream. This must be the stream
    // created above (with no other uses) because we check that the synchronous
    // version does not have any async dependencies.
    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
    streamDestroyCallBuilder.create(loc, rewriter, stream);
    rewriter.eraseOp(launchOp);
  }
  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));

  return success();
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memcpyOp.src().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memcpyOp, operands, rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();
  auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary());

  MemRefDescriptor srcDesc(adaptor.src());

  Value numElements =
      memRefType.hasStaticShape()
          ? createIndexConstant(rewriter, loc, memRefType.getNumElements())
          // For identity layouts (verified above), the number of elements is
          // stride[0] * size[0].
          : rewriter.create<LLVM::MulOp>(loc, srcDesc.stride(rewriter, loc, 0),
                                         srcDesc.size(rewriter, loc, 0));

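  // Compute the copy size in bytes with the usual LLVM "sizeof" idiom: GEP to
  // element `numElements` off a null pointer of the element type, then
  // ptrtoint the result.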
  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(
      loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements});
  auto sizeBytes =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);

  auto src = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType,
      MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createGpuToLLVMConversionPass() {
  return std::make_unique<GpuToLLVMConversionPass>();
}