//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert the gpu.launch_func op into a
// sequence of GPU runtime calls. As most GPU runtimes do not have a stable
// published ABI, this pass uses a slim runtime layer that builds on top of the
// public API from GPU runtime headers.
//
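// For example (a sketch, not verbatim output), a kernel launch such as
//
//   gpu.launch_func @kernels::@kernel
//       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
//       args(%arg0 : f32)
//
// becomes calls to mgpuModuleLoad, mgpuModuleGetFunction, mgpuStreamCreate,
// mgpuLaunchKernel and mgpuStreamSynchronize, with the kernel binary taken
// from an attribute on the surrounding gpu.module.
//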
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";

namespace {

class GpuToLLVMConversionPass
    : public GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  GpuToLLVMConversionPass() = default;

  GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
      : GpuToLLVMConversionPassBase(other) {}

  // Run the dialect converter on the module.
  void runOnOperation() override;

private:
  Option<std::string> gpuBinaryAnnotation{
      *this, "gpu-binary-annotation",
      llvm::cl::desc("Annotation attribute string for GPU binary"),
      llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
};

struct FunctionCallBuilder {
  FunctionCallBuilder(StringRef functionName, Type returnType,
                      ArrayRef<Type> argumentTypes)
      : functionName(functionName),
        functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
  LLVM::CallOp create(Location loc, OpBuilder &builder,
                      ArrayRef<Value> arguments) const;

  StringRef functionName;
  LLVM::LLVMFunctionType functionType;
};
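
// `create` above lazily declares the runtime function at module scope on
// first use and emits an llvm.call to it. For example (a sketch; mgpuFoo is
// a hypothetical wrapper function):
//
//   FunctionCallBuilder fooBuilder = {
//       "mgpuFoo", llvmVoidType, {llvmPointerType /* void *ptr */}};
//   fooBuilder.create(loc, rewriter, {somePtrValue});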

template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    return type.hasStaticShape()
               ? ConvertToLLVMPattern::createIndexConstant(
                     rewriter, loc, type.getNumElements())
               // For identity maps (verified by caller), the number of
               // elements is stride[0] * size[0].
               : rewriter.create<LLVM::MulOp>(loc,
                                              desc.stride(rewriter, loc, 0),
                                              desc.size(rewriter, loc, 0));
  }

  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  Type llvmPointerType =
      LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
  Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

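  // Builders for the mgpu* runtime wrapper calls emitted below. Their types
  // imply C signatures roughly like the following (a sketch; the
  // authoritative declarations live in the CUDA / ROCm runtime wrapper
  // libraries):
  //
  //   void *mgpuModuleLoad(void *cubinOrHsaco);
  //   void *mgpuStreamCreate();
  //   void mgpuLaunchKernel(void *fn, intptr_t gridX, intptr_t gridY,
  //                         intptr_t gridZ, intptr_t blockX, intptr_t blockY,
  //                         intptr_t blockZ, unsigned sharedMemBytes,
  //                         void *stream, void **params, void **extra);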
  FunctionCallBuilder moduleLoadCallBuilder = {
      "mgpuModuleLoad",
      llvmPointerType /* void *module */,
      {llvmPointerType /* void *cubin */}};
  FunctionCallBuilder moduleUnloadCallBuilder = {
      "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
  FunctionCallBuilder moduleGetFunctionCallBuilder = {
      "mgpuModuleGetFunction",
      llvmPointerType /* void *function */,
      {
          llvmPointerType, /* void *module */
          llvmPointerType  /* char *name   */
      }};
  FunctionCallBuilder launchKernelCallBuilder = {
      "mgpuLaunchKernel",
      llvmVoidType,
      {
          llvmPointerType,        /* void* f */
          llvmIntPtrType,         /* intptr_t gridXDim */
          llvmIntPtrType,         /* intptr_t gridYDim */
          llvmIntPtrType,         /* intptr_t gridZDim */
          llvmIntPtrType,         /* intptr_t blockXDim */
          llvmIntPtrType,         /* intptr_t blockYDim */
          llvmIntPtrType,         /* intptr_t blockZDim */
          llvmInt32Type,          /* unsigned int sharedMemBytes */
          llvmPointerType,        /* void *hstream */
          llvmPointerPointerType, /* void **kernelParams */
          llvmPointerPointerType  /* void **extra */
      }};
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memsetCallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
       llvmIntPtrType /* intptr_t count */,
       llvmPointerType /* void *stream */}};
};

/// A rewrite pattern to convert gpu.host_register operations into a GPU
/// runtime call. Currently it supports CUDA and ROCm (HIP).
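///
/// For example (a sketch; descriptor handling depends on the type converter):
///
///   gpu.host_register %buffer : memref<*xf32>
///
/// becomes a call to mgpuMemHostRegisterMemRef passing the memref rank, a
/// pointer to the descriptor, and the element size in bytes.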
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
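///
/// For example (a sketch), the async form
///
///   %memref, %token = gpu.alloc async [%dep] (%size) : memref<?xf32>
///
/// becomes a call to mgpuMemAlloc on the stream %dep lowered to, and the
/// returned pointer is wrapped in a memref descriptor.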
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.launch_func operations into a sequence of
/// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
///
/// In essence, a gpu.launch_func operation gets compiled into the following
/// sequence of runtime calls:
///
/// * moduleLoad        -- loads the module given the cubin / hsaco data
/// * moduleGetFunction -- gets a handle to the actual kernel function
/// * streamCreate      -- creates a new compute stream on the GPU
/// * launchKernel      -- launches the kernel on a stream
/// * streamSynchronize -- waits for operations on the stream to finish
///
/// Intermediate data structures are allocated on the stack.
class ConvertLaunchFuncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
                                             StringRef gpuBinaryAnnotation)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        gpuBinaryAnnotation(gpuBinaryAnnotation) {}

private:
  Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                            OpBuilder &builder) const;
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

  llvm::SmallString<32> gpuBinaryAnnotation;
};

class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
  using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
                                PatternRewriter &rewriter) const override {
    // GPU kernel modules are no longer necessary since we have a global
    // constant with the CUBIN or HSACO data.
    rewriter.eraseOp(op);
    return success();
  }
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
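///
/// For example (a sketch), the async form
///
///   %token = gpu.memcpy async [%dep] %dst, %src : memref<?xf32>, memref<?xf32>
///
/// becomes a call to mgpuMemcpy, with the size in bytes computed from the
/// source memref descriptor.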
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
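///
/// For example (a sketch), the async form
///
///   %token = gpu.memset async [%dep] %dst, %value : memref<?xf32>, f32
///
/// becomes a call to mgpuMemset32, with the 32-bit value bitcast to i32 and
/// the element count taken from the destination memref descriptor.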
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
  LLVMTypeConverter converter(&getContext());
  RewritePatternSet patterns(&getContext());
  LLVMConversionTarget target(getContext());

  target.addIllegalDialect<gpu::GPUDialect>();

  mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateMemRefToLLVMConversionPatterns(converter, patterns);
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);
  populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}
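
// This pass is typically exercised from the command line (a sketch, assuming
// the registered pass name and option spelling):
//
//   mlir-opt --gpu-to-llvm input.mlir
//   mlir-opt --gpu-to-llvm='gpu-binary-annotation=rocdl.hsaco' input.mlir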

LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(loc, function, arguments);
}

// Returns whether all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.value().getType();
  auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  auto loc = allocOp.getLoc();

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.dynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  Type elementPtrType = this->getElementPtrType(memRefType);
  auto stream = adaptor.asyncDependencies().front();
  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
  allocatedPtr =
      rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});

  return success();
}

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  Value pointer =
      MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
  auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
  Value stream = adaptor.asyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {casted, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return value.getType().isa<gpu::AsyncTokenType>();
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token values are lowered to streams within the async.execute
// region, but are passed as events between regions. For each !gpu.async.token
// operand, we create an event and record it on the stream.
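//
// For example (a sketch), for a region that yields a lowered token
//
//   %token = gpu.launch_func async ...
//   async.yield %token : !gpu.async.token
//
// the rewrite creates an event (mgpuEventCreate), records it on the stream
// (mgpuEventRecord), destroys the stream, and yields the event instead.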
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(adaptor.getOperands());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = adaptor.getOperands()[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.updateRootInPlace(yieldOp,
                             [&] { yieldOp->setOperands(newOperands); });
  return success();
}

// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(value.getType().isa<LLVM::LLVMPointerType>());
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return defOp.getCallee()->equals(functionName);
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
// with the stream/event operands. The operands are destroyed. That is, the
// pattern assumes that they are not used afterwards or elsewhere; otherwise we
// will get a runtime error. Eventually, we should guarantee this property.
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : adaptor.getOperands()) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands. The operands are
// destroyed. That is, the pattern assumes that they are not used afterwards or
// elsewhere; otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair :
       llvm::zip(waitOp.asyncDependencies(), adaptor.getOperands())) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event =
          eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
  auto loc = launchOp.getLoc();
  auto numKernelOperands = launchOp.getNumKernelOperands();
  auto arguments = getTypeConverter()->promoteOperands(
      loc, launchOp.getOperands().take_back(numKernelOperands),
      adaptor.getOperands().take_back(numKernelOperands), builder);
  auto numArguments = arguments.size();
  SmallVector<Type, 4> argumentTypes;
  argumentTypes.reserve(numArguments);
  for (auto argument : arguments)
    argumentTypes.push_back(argument.getType());
  auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
                                                           argumentTypes);
  auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                              builder.getI32IntegerAttr(1));
  auto structPtr = builder.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
  auto arraySize = builder.create<LLVM::ConstantOp>(
      loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
  auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
                                                 arraySize, /*alignment=*/0);
  auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               builder.getI32IntegerAttr(0));
  for (auto en : llvm::enumerate(arguments)) {
    auto index = builder.create<LLVM::ConstantOp>(
        loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
    auto fieldPtr = builder.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
        ArrayRef<Value>{zero, index.getResult()});
    builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
    auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
                                                  arrayPtr, index.getResult());
    auto casted =
        builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
    builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
  }
  return arrayPtr;
}

// Generates an LLVM IR dialect global that contains the name of the given
// kernel function as a C string, and returns a pointer to its beginning.
// The code is essentially:
//
// llvm.global constant @kernel_name("function_name\00")
// func(...) {
//   %0 = llvm.addressof @kernel_name
//   %1 = llvm.constant (0 : index)
//   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
// }
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
    StringRef moduleName, StringRef name, Location loc,
    OpBuilder &builder) const {
  // Make sure the trailing zero is included in the constant.
  std::vector<char> kernelName(name.begin(), name.end());
  kernelName.push_back('\0');

  std::string globalName =
      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
  return LLVM::createGlobalString(
      loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
      LLVM::Linkage::Internal);
}

// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = call %moduleLoad(%0)
// %2 = <see generateKernelNameConstant>
// %3 = call %moduleGetFunction(%1, %2)
// %4 = call %streamCreate()
// %5 = <see generateParamsArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
// call %streamDestroy(%4)
// call %moduleUnload(%1)
//
// If the op is async, the stream corresponds to the (single) async dependency
// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
    return failure();

  if (launchOp.asyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launchOp, launchOp.getKernelModuleName());
  assert(kernelModule && "expected a kernel module");

  auto binaryAttr =
      kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
  if (!binaryAttr) {
    kernelModule.emitOpError()
        << "missing " << gpuBinaryAnnotation << " attribute";
    return failure();
  }

  SmallString<128> nameBuffer(kernelModule.getName());
  nameBuffer.append(kGpuBinaryStorageSuffix);
  Value data =
      LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
                               binaryAttr.getValue(), LLVM::Linkage::Internal);

  auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
  // Get the function from the module. The name corresponds to the name of
  // the kernel function.
  auto kernelName = generateKernelNameConstant(
      launchOp.getKernelModuleName().getValue(),
      launchOp.getKernelName().getValue(), loc, rewriter);
  auto function = moduleGetFunctionCallBuilder.create(
      loc, rewriter, {module.getResult(0), kernelName});
  auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                                rewriter.getI32IntegerAttr(0));
  Value stream =
      adaptor.asyncDependencies().empty()
          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
          : adaptor.asyncDependencies().front();
  // Create array of pointers to kernel arguments.
  auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
  auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
  Value dynamicSharedMemorySize = launchOp.dynamicSharedMemorySize()
                                      ? launchOp.dynamicSharedMemorySize()
                                      : zero;
  launchKernelCallBuilder.create(
      loc, rewriter,
      {function.getResult(0), adaptor.gridSizeX(), adaptor.gridSizeY(),
       adaptor.gridSizeZ(), adaptor.blockSizeX(), adaptor.blockSizeY(),
       adaptor.blockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
       /*extra=*/nullpointer});

  if (launchOp.asyncToken()) {
    // Async launch: make dependent ops use the same stream.
    rewriter.replaceOp(launchOp, {stream});
  } else {
    // Synchronize with host and destroy stream. This must be the stream
    // created above (with no other uses) because we check that the synchronous
    // version does not have any async dependencies.
    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
    streamDestroyCallBuilder.create(loc, rewriter, stream);
    rewriter.eraseOp(launchOp);
  }
  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));

  return success();
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memcpyOp.src().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();

  MemRefDescriptor srcDesc(adaptor.src());
  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

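  // Compute the copy size in bytes with the null-pointer GEP idiom: index a
  // null pointer of the element type by numElements and ptrtoint the result,
  // yielding numElements * sizeof(element) without hard-coding element size.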
  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(
      loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements});
  auto sizeBytes =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);

  auto src = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType,
      MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}

LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemsetOp memsetOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memsetOp.dst().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
    return failure();

  auto loc = memsetOp.getLoc();

  Type valueType = adaptor.value().getType();
  if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
    return rewriter.notifyMatchFailure(memsetOp,
                                       "value must be a 32 bit scalar");
  }

  MemRefDescriptor dstDesc(adaptor.dst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.value());
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createGpuToLLVMConversionPass() {
  return std::make_unique<GpuToLLVMConversionPass>();
}

void mlir::populateGpuToLLVMConversionPatterns(
    LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
    StringRef gpuBinaryAnnotation) {
  converter.addConversion(
      [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
        return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
      });
  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertMemsetOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
                                                           gpuBinaryAnnotation);
  patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
}