//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert gpu.launch_func ops into a sequence
// of GPU runtime calls. As most GPU runtimes do not have a stable published
// ABI, this pass uses a slim runtime layer that builds on top of the public
// API from GPU runtime headers.
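//
// As an illustrative sketch (syntax abbreviated), a kernel launch such as
//
//   gpu.launch_func @kernels::@kernel
//       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
//       args(%arg0 : f32)
//
// becomes calls to mgpuModuleLoad, mgpuModuleGetFunction, mgpuStreamCreate,
// mgpuLaunchKernel, and mgpuStreamSynchronize as declared below.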
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";

namespace {

class GpuToLLVMConversionPass
    : public GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  GpuToLLVMConversionPass() = default;

  GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
      : GpuToLLVMConversionPassBase(other) {}

  // Run the dialect converter on the module.
  void runOnOperation() override;

private:
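  // The annotation below is exposed as a pass option. For example (assuming
  // the pass is registered as "gpu-to-llvm"), it can be overridden with:
  //   mlir-opt --gpu-to-llvm="gpu-binary-annotation=rocdl.hsaco" input.mlir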
  Option<std::string> gpuBinaryAnnotation{
      *this, "gpu-binary-annotation",
      llvm::cl::desc("Annotation attribute string for GPU binary"),
      llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
};

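/// Helper that declares (on first use) and calls one of the C wrapper
/// functions from the GPU runtime library. A minimal usage sketch, where
/// "mgpuFoo" is a placeholder runtime function name:
///
///   FunctionCallBuilder fooBuilder = {"mgpuFoo", llvmVoidType,
///                                     {llvmPointerType /* void *ptr */}};
///   fooBuilder.create(loc, rewriter, {ptr});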
struct FunctionCallBuilder {
  FunctionCallBuilder(StringRef functionName, Type returnType,
                      ArrayRef<Type> argumentTypes)
      : functionName(functionName),
        functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
  LLVM::CallOp create(Location loc, OpBuilder &builder,
                      ArrayRef<Value> arguments) const;

  StringRef functionName;
  LLVM::LLVMFunctionType functionType;
};

template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    return type.hasStaticShape()
               ? ConvertToLLVMPattern::createIndexConstant(
                     rewriter, loc, type.getNumElements())
               // For identity maps (verified by caller), the number of
               // elements is stride[0] * size[0].
               : rewriter.create<LLVM::MulOp>(loc,
                                              desc.stride(rewriter, loc, 0),
                                              desc.size(rewriter, loc, 0));
  }

  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  Type llvmPointerType =
      LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
  Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

  FunctionCallBuilder moduleLoadCallBuilder = {
      "mgpuModuleLoad",
      llvmPointerType /* void *module */,
      {llvmPointerType /* void *cubin */}};
  FunctionCallBuilder moduleUnloadCallBuilder = {
      "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
  FunctionCallBuilder moduleGetFunctionCallBuilder = {
      "mgpuModuleGetFunction",
      llvmPointerType /* void *function */,
      {
          llvmPointerType, /* void *module */
          llvmPointerType  /* char *name   */
      }};
  FunctionCallBuilder launchKernelCallBuilder = {
      "mgpuLaunchKernel",
      llvmVoidType,
      {
          llvmPointerType,        /* void* f */
          llvmIntPtrType,         /* intptr_t gridXDim */
          llvmIntPtrType,         /* intptr_t gridYDim */
          llvmIntPtrType,         /* intptr_t gridZDim */
          llvmIntPtrType,         /* intptr_t blockXDim */
          llvmIntPtrType,         /* intptr_t blockYDim */
          llvmIntPtrType,         /* intptr_t blockZDim */
          llvmInt32Type,          /* unsigned int sharedMemBytes */
          llvmPointerType,        /* void *hstream */
          llvmPointerPointerType, /* void **kernelParams */
          llvmPointerPointerType  /* void **extra */
      }};
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memsetCallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder setDefaultDeviceCallBuilder = {
      "mgpuSetDefaultDevice",
      llvmVoidType,
      {llvmInt32Type /* uint32_t devIndex */}};
};

/// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

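/// A rewrite pattern to convert the !gpu.async.token operands of async.yield
/// operations into events recorded on the corresponding streams (see the
/// comment on the matchAndRewrite definition below).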
class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.launch_func operations into a sequence of
/// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
///
/// In essence, a gpu.launch_func operation gets compiled into the following
/// sequence of runtime calls:
///
/// * moduleLoad        -- loads the module given the cubin / hsaco data
/// * moduleGetFunction -- gets a handle to the actual kernel function
/// * streamCreate      -- initializes a new compute stream on GPU
/// * launchKernel      -- launches the kernel on a stream
/// * streamSynchronize -- waits for operations on the stream to finish
///
/// Intermediate data structures are allocated on the stack.
class ConvertLaunchFuncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
                                             StringRef gpuBinaryAnnotation)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        gpuBinaryAnnotation(gpuBinaryAnnotation) {}

private:
  Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                            OpBuilder &builder) const;
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

  llvm::SmallString<32> gpuBinaryAnnotation;
};

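/// Erases gpu.module ops, which become unnecessary once the device code has
/// been embedded into the host module as a global string constant.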
class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
  using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
                                PatternRewriter &rewriter) const override {
    // GPU kernel modules are no longer necessary since we have a global
    // constant with the CUBIN or HSACO data.
    rewriter.eraseOp(op);
    return success();
  }
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
/// Currently it supports CUDA and ROCm (HIP).
class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
public:
  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
      LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
            typeConverter) {}

  LogicalResult
  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
  LLVMTypeConverter converter(&getContext());
  RewritePatternSet patterns(&getContext());
  LLVMConversionTarget target(getContext());

  target.addIllegalDialect<gpu::GPUDialect>();

  mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
  mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateMemRefToLLVMConversionPatterns(converter, patterns);
  populateFuncToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);
  populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(loc, function, arguments);
}

// Returns success if all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

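// Returns success if the op is the async variant with exactly one async
// dependency; the lowering uses that single dependency as the stream to
// operate on.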
static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.value().getType();
  auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  auto loc = allocOp.getLoc();

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.dynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  Type elementPtrType = this->getElementPtrType(memRefType);
  auto stream = adaptor.asyncDependencies().front();
  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
  allocatedPtr =
      rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});

  return success();
}

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  Value pointer =
      MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
  auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
  Value stream = adaptor.asyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {casted, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return value.getType().isa<gpu::AsyncTokenType>();
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls.
// !gpu.async.token values are lowered to streams within each async.execute
// region, but are passed as events between regions. For each !gpu.async.token
// operand, we create an event and record it on the stream.
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(adaptor.getOperands());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = adaptor.getOperands()[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.updateRootInPlace(yieldOp,
                             [&] { yieldOp->setOperands(newOperands); });
  return success();
}

// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(value.getType().isa<LLVM::LLVMPointerType>());
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return defOp.getCallee()->equals(functionName);
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
// with the stream/event operands, which are then destroyed. That is, the
// lowering assumes they are not used afterwards or elsewhere; otherwise we
// will get a runtime error. Eventually, we should guarantee this property.
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : adaptor.getOperands()) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands, which are then
// destroyed. That is, the lowering assumes they are not used afterwards or
// elsewhere; otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair :
       llvm::zip(waitOp.asyncDependencies(), adaptor.getOperands())) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event =
          eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
  auto loc = launchOp.getLoc();
  auto numKernelOperands = launchOp.getNumKernelOperands();
  auto arguments = getTypeConverter()->promoteOperands(
      loc, launchOp.getOperands().take_back(numKernelOperands),
      adaptor.getOperands().take_back(numKernelOperands), builder);
  auto numArguments = arguments.size();
  SmallVector<Type, 4> argumentTypes;
  argumentTypes.reserve(numArguments);
  for (auto argument : arguments)
    argumentTypes.push_back(argument.getType());
  auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
                                                           argumentTypes);
  auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                              builder.getI32IntegerAttr(1));
  auto structPtr = builder.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
  auto arraySize = builder.create<LLVM::ConstantOp>(
      loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
  auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
                                                 arraySize, /*alignment=*/0);
  auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               builder.getI32IntegerAttr(0));
  for (const auto &en : llvm::enumerate(arguments)) {
    auto index = builder.create<LLVM::ConstantOp>(
        loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
    auto fieldPtr = builder.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
        ArrayRef<Value>{zero, index.getResult()});
    builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
    auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
                                                  arrayPtr, index.getResult());
    auto casted =
        builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
    builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
  }
  return arrayPtr;
}

// Generates an LLVM IR dialect global that contains the name of the given
// kernel function as a C string, and returns a pointer to its beginning.
// The code is essentially:
//
// llvm.global constant @kernel_name("function_name\00")
// func(...) {
//   %0 = llvm.addressof @kernel_name
//   %1 = llvm.constant (0 : index)
//   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
// }
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
    StringRef moduleName, StringRef name, Location loc,
    OpBuilder &builder) const {
  // Make sure the trailing zero is included in the constant.
  std::vector<char> kernelName(name.begin(), name.end());
  kernelName.push_back('\0');

  std::string globalName =
      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
  return LLVM::createGlobalString(
      loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
      LLVM::Linkage::Internal);
}

// Emits LLVM IR to launch a kernel function. Expects the gpu.module that
// contains the kernel to carry the compiled binary as a cubin in its
// 'nvvm.cubin' attribute, or as a hsaco in its 'rocdl.hsaco' attribute.
//
// %0 = call %binarygetter
// %1 = call %moduleLoad(%0)
// %2 = <see generateKernelNameConstant>
// %3 = call %moduleGetFunction(%1, %2)
// %4 = call %streamCreate()
// %5 = <see generateParamsArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
// call %streamDestroy(%4)
// call %moduleUnload(%1)
//
// If the op is async, the stream corresponds to the (single) async dependency
// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
    return failure();

  if (launchOp.asyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launchOp, launchOp.getKernelModuleName());
  assert(kernelModule && "expected a kernel module");

  auto binaryAttr =
      kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
  if (!binaryAttr) {
    kernelModule.emitOpError()
        << "missing " << gpuBinaryAnnotation << " attribute";
    return failure();
  }

  SmallString<128> nameBuffer(kernelModule.getName());
  nameBuffer.append(kGpuBinaryStorageSuffix);
  Value data =
      LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
                               binaryAttr.getValue(), LLVM::Linkage::Internal);

  auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
  // Get the function from the module. The name corresponds to the name of
  // the kernel function.
  auto kernelName = generateKernelNameConstant(
      launchOp.getKernelModuleName().getValue(),
      launchOp.getKernelName().getValue(), loc, rewriter);
  auto function = moduleGetFunctionCallBuilder.create(
      loc, rewriter, {module.getResult(0), kernelName});
  auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                                rewriter.getI32IntegerAttr(0));
  Value stream =
      adaptor.asyncDependencies().empty()
          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
          : adaptor.asyncDependencies().front();
  // Create array of pointers to kernel arguments.
  auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
  auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
  Value dynamicSharedMemorySize = launchOp.dynamicSharedMemorySize()
                                      ? launchOp.dynamicSharedMemorySize()
                                      : zero;
  launchKernelCallBuilder.create(
      loc, rewriter,
      {function.getResult(0), adaptor.gridSizeX(), adaptor.gridSizeY(),
       adaptor.gridSizeZ(), adaptor.blockSizeX(), adaptor.blockSizeY(),
       adaptor.blockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
       /*extra=*/nullpointer});

  if (launchOp.asyncToken()) {
    // Async launch: make dependent ops use the same stream.
    rewriter.replaceOp(launchOp, {stream});
  } else {
    // Synchronize with the host and destroy the stream. This must be the
    // stream created above (with no other uses) because we check that the
    // synchronous version does not have any async dependencies.
    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
    streamDestroyCallBuilder.create(loc, rewriter, stream);
    rewriter.eraseOp(launchOp);
  }
  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));

  return success();
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memcpyOp.src().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();

  MemRefDescriptor srcDesc(adaptor.src());
  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
                                              ArrayRef<Value>{numElements});
  auto sizeBytes =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);

  auto src = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType,
      MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}

LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemsetOp memsetOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memsetOp.dst().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
    return failure();

  auto loc = memsetOp.getLoc();

  Type valueType = adaptor.value().getType();
  if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
    return rewriter.notifyMatchFailure(memsetOp,
                                       "value must be a 32 bit scalar");
  }

  MemRefDescriptor dstDesc(adaptor.dst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.value());
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}

LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  setDefaultDeviceCallBuilder.create(loc, rewriter, {adaptor.devIndex()});
  rewriter.replaceOp(op, {});
  return success();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createGpuToLLVMConversionPass() {
  return std::make_unique<GpuToLLVMConversionPass>();
}

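// Populates the given list with patterns that lower GPU dialect ops to GPU
// runtime calls. A minimal usage sketch, mirroring
// GpuToLLVMConversionPass::runOnOperation above:
//
//   LLVMTypeConverter converter(ctx);
//   RewritePatternSet patterns(ctx);
//   populateGpuToLLVMConversionPatterns(
//       converter, patterns, gpu::getDefaultGpuBinaryAnnotation());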
void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns,
                                               StringRef gpuBinaryAnnotation) {
  converter.addConversion(
      [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
        return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
      });
  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertMemsetOpToGpuRuntimeCallPattern,
               ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
                                                           gpuBinaryAnnotation);
  patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
}