1a5f9cda1SChristian Sigg //===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
2a5f9cda1SChristian Sigg //
3a5f9cda1SChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a5f9cda1SChristian Sigg // See https://llvm.org/LICENSE.txt for license information.
5a5f9cda1SChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a5f9cda1SChristian Sigg //
7a5f9cda1SChristian Sigg //===----------------------------------------------------------------------===//
8a5f9cda1SChristian Sigg //
9a5f9cda1SChristian Sigg // This file implements a pass to convert gpu.launch_func op into a sequence of
10a5f9cda1SChristian Sigg // GPU runtime calls. As most of GPU runtimes does not have a stable published
11a5f9cda1SChristian Sigg // ABI, this pass uses a slim runtime layer that builds on top of the public
12a5f9cda1SChristian Sigg // API from GPU runtime headers.
13a5f9cda1SChristian Sigg //
14a5f9cda1SChristian Sigg //===----------------------------------------------------------------------===//
15a5f9cda1SChristian Sigg 
16a5f9cda1SChristian Sigg #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
17a5f9cda1SChristian Sigg 
18a5f9cda1SChristian Sigg #include "../PassDetail.h"
19a54f4eaeSMogball #include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
20a5f9cda1SChristian Sigg #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
21330838ebSRiver Riddle #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
225a7b9194SRiver Riddle #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
235a7b9194SRiver Riddle #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
2475e5f0aaSAlex Zinenko #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
25684dfe8aSAlex Zinenko #include "mlir/Conversion/LLVMCommon/Pattern.h"
2675e5f0aaSAlex Zinenko #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
27a5f9cda1SChristian Sigg #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
28a5f9cda1SChristian Sigg #include "mlir/Dialect/Async/IR/Async.h"
29d7ef488bSMogball #include "mlir/Dialect/GPU/IR/GPUDialect.h"
30d7ef488bSMogball #include "mlir/Dialect/GPU/Transforms/Passes.h"
31a5f9cda1SChristian Sigg #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
32a5f9cda1SChristian Sigg #include "mlir/IR/Attributes.h"
33a5f9cda1SChristian Sigg #include "mlir/IR/Builders.h"
34a5f9cda1SChristian Sigg #include "mlir/IR/BuiltinOps.h"
35a5f9cda1SChristian Sigg #include "mlir/IR/BuiltinTypes.h"
36a5f9cda1SChristian Sigg 
37a5f9cda1SChristian Sigg #include "llvm/ADT/STLExtras.h"
38a5f9cda1SChristian Sigg #include "llvm/Support/Error.h"
39a5f9cda1SChristian Sigg #include "llvm/Support/FormatVariadic.h"
40a5f9cda1SChristian Sigg 
41a5f9cda1SChristian Sigg using namespace mlir;
42a5f9cda1SChristian Sigg 
// Suffix for the symbol name of the global that stores a GPU binary.
// NOTE(review): presumably appended to a GPU module / kernel name to name the
// global constant holding the CUBIN/HSACO data; the use site is not visible
// here -- confirm.
static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";
44a5f9cda1SChristian Sigg 
45a5f9cda1SChristian Sigg namespace {
46a5f9cda1SChristian Sigg 
47a5f9cda1SChristian Sigg class GpuToLLVMConversionPass
48a5f9cda1SChristian Sigg     : public GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
49a5f9cda1SChristian Sigg public:
50a5f9cda1SChristian Sigg   GpuToLLVMConversionPass() = default;
51a5f9cda1SChristian Sigg 
GpuToLLVMConversionPass(const GpuToLLVMConversionPass & other)52a5f9cda1SChristian Sigg   GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
53a5f9cda1SChristian Sigg       : GpuToLLVMConversionPassBase(other) {}
54a5f9cda1SChristian Sigg 
55a5f9cda1SChristian Sigg   // Run the dialect converter on the module.
56a5f9cda1SChristian Sigg   void runOnOperation() override;
57a5f9cda1SChristian Sigg 
58a5f9cda1SChristian Sigg private:
59a5f9cda1SChristian Sigg   Option<std::string> gpuBinaryAnnotation{
60a5f9cda1SChristian Sigg       *this, "gpu-binary-annotation",
61a5f9cda1SChristian Sigg       llvm::cl::desc("Annotation attribute string for GPU binary"),
62a5f9cda1SChristian Sigg       llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
63a5f9cda1SChristian Sigg };
64a5f9cda1SChristian Sigg 
65a5f9cda1SChristian Sigg struct FunctionCallBuilder {
FunctionCallBuilder__anon6cc9e7020111::FunctionCallBuilder66a5f9cda1SChristian Sigg   FunctionCallBuilder(StringRef functionName, Type returnType,
67a5f9cda1SChristian Sigg                       ArrayRef<Type> argumentTypes)
68a5f9cda1SChristian Sigg       : functionName(functionName),
69a5f9cda1SChristian Sigg         functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
70a5f9cda1SChristian Sigg   LLVM::CallOp create(Location loc, OpBuilder &builder,
71a5f9cda1SChristian Sigg                       ArrayRef<Value> arguments) const;
72a5f9cda1SChristian Sigg 
73a5f9cda1SChristian Sigg   StringRef functionName;
74a5f9cda1SChristian Sigg   LLVM::LLVMFunctionType functionType;
75a5f9cda1SChristian Sigg };
76a5f9cda1SChristian Sigg 
77a5f9cda1SChristian Sigg template <typename OpTy>
78a5f9cda1SChristian Sigg class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
79a5f9cda1SChristian Sigg public:
ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)80a5f9cda1SChristian Sigg   explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
81a5f9cda1SChristian Sigg       : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}
82a5f9cda1SChristian Sigg 
83a5f9cda1SChristian Sigg protected:
getNumElements(ConversionPatternRewriter & rewriter,Location loc,MemRefType type,MemRefDescriptor desc) const84361458b1SLoren Maggiore   Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
85361458b1SLoren Maggiore                        MemRefType type, MemRefDescriptor desc) const {
86361458b1SLoren Maggiore     return type.hasStaticShape()
87361458b1SLoren Maggiore                ? ConvertToLLVMPattern::createIndexConstant(
88361458b1SLoren Maggiore                      rewriter, loc, type.getNumElements())
89361458b1SLoren Maggiore                // For identity maps (verified by caller), the number of
90361458b1SLoren Maggiore                // elements is stride[0] * size[0].
91361458b1SLoren Maggiore                : rewriter.create<LLVM::MulOp>(loc,
92361458b1SLoren Maggiore                                               desc.stride(rewriter, loc, 0),
93361458b1SLoren Maggiore                                               desc.size(rewriter, loc, 0));
94361458b1SLoren Maggiore   }
95361458b1SLoren Maggiore 
96a5f9cda1SChristian Sigg   MLIRContext *context = &this->getTypeConverter()->getContext();
97a5f9cda1SChristian Sigg 
98a5f9cda1SChristian Sigg   Type llvmVoidType = LLVM::LLVMVoidType::get(context);
99a5f9cda1SChristian Sigg   Type llvmPointerType =
100a5f9cda1SChristian Sigg       LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
101a5f9cda1SChristian Sigg   Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
102a5f9cda1SChristian Sigg   Type llvmInt8Type = IntegerType::get(context, 8);
103a5f9cda1SChristian Sigg   Type llvmInt32Type = IntegerType::get(context, 32);
104a5f9cda1SChristian Sigg   Type llvmInt64Type = IntegerType::get(context, 64);
105a5f9cda1SChristian Sigg   Type llvmIntPtrType = IntegerType::get(
106a5f9cda1SChristian Sigg       context, this->getTypeConverter()->getPointerBitwidth(0));
107a5f9cda1SChristian Sigg 
108a5f9cda1SChristian Sigg   FunctionCallBuilder moduleLoadCallBuilder = {
109a5f9cda1SChristian Sigg       "mgpuModuleLoad",
110a5f9cda1SChristian Sigg       llvmPointerType /* void *module */,
111a5f9cda1SChristian Sigg       {llvmPointerType /* void *cubin */}};
112a5f9cda1SChristian Sigg   FunctionCallBuilder moduleUnloadCallBuilder = {
113a5f9cda1SChristian Sigg       "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
114a5f9cda1SChristian Sigg   FunctionCallBuilder moduleGetFunctionCallBuilder = {
115a5f9cda1SChristian Sigg       "mgpuModuleGetFunction",
116a5f9cda1SChristian Sigg       llvmPointerType /* void *function */,
117a5f9cda1SChristian Sigg       {
118a5f9cda1SChristian Sigg           llvmPointerType, /* void *module */
119a5f9cda1SChristian Sigg           llvmPointerType  /* char *name   */
120a5f9cda1SChristian Sigg       }};
121a5f9cda1SChristian Sigg   FunctionCallBuilder launchKernelCallBuilder = {
122a5f9cda1SChristian Sigg       "mgpuLaunchKernel",
123a5f9cda1SChristian Sigg       llvmVoidType,
124a5f9cda1SChristian Sigg       {
125a5f9cda1SChristian Sigg           llvmPointerType,        /* void* f */
126a5f9cda1SChristian Sigg           llvmIntPtrType,         /* intptr_t gridXDim */
127a5f9cda1SChristian Sigg           llvmIntPtrType,         /* intptr_t gridyDim */
128a5f9cda1SChristian Sigg           llvmIntPtrType,         /* intptr_t gridZDim */
129a5f9cda1SChristian Sigg           llvmIntPtrType,         /* intptr_t blockXDim */
130a5f9cda1SChristian Sigg           llvmIntPtrType,         /* intptr_t blockYDim */
131a5f9cda1SChristian Sigg           llvmIntPtrType,         /* intptr_t blockZDim */
132a5f9cda1SChristian Sigg           llvmInt32Type,          /* unsigned int sharedMemBytes */
133a5f9cda1SChristian Sigg           llvmPointerType,        /* void *hstream */
134a5f9cda1SChristian Sigg           llvmPointerPointerType, /* void **kernelParams */
135a5f9cda1SChristian Sigg           llvmPointerPointerType  /* void **extra */
136a5f9cda1SChristian Sigg       }};
137a5f9cda1SChristian Sigg   FunctionCallBuilder streamCreateCallBuilder = {
138a5f9cda1SChristian Sigg       "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
139a5f9cda1SChristian Sigg   FunctionCallBuilder streamDestroyCallBuilder = {
140a5f9cda1SChristian Sigg       "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
141a5f9cda1SChristian Sigg   FunctionCallBuilder streamSynchronizeCallBuilder = {
142a5f9cda1SChristian Sigg       "mgpuStreamSynchronize",
143a5f9cda1SChristian Sigg       llvmVoidType,
144a5f9cda1SChristian Sigg       {llvmPointerType /* void *stream */}};
145a5f9cda1SChristian Sigg   FunctionCallBuilder streamWaitEventCallBuilder = {
146a5f9cda1SChristian Sigg       "mgpuStreamWaitEvent",
147a5f9cda1SChristian Sigg       llvmVoidType,
148a5f9cda1SChristian Sigg       {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
149a5f9cda1SChristian Sigg   FunctionCallBuilder eventCreateCallBuilder = {
150a5f9cda1SChristian Sigg       "mgpuEventCreate", llvmPointerType /* void *event */, {}};
151a5f9cda1SChristian Sigg   FunctionCallBuilder eventDestroyCallBuilder = {
152a5f9cda1SChristian Sigg       "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
153a5f9cda1SChristian Sigg   FunctionCallBuilder eventSynchronizeCallBuilder = {
154a5f9cda1SChristian Sigg       "mgpuEventSynchronize",
155a5f9cda1SChristian Sigg       llvmVoidType,
156a5f9cda1SChristian Sigg       {llvmPointerType /* void *event */}};
157a5f9cda1SChristian Sigg   FunctionCallBuilder eventRecordCallBuilder = {
158a5f9cda1SChristian Sigg       "mgpuEventRecord",
159a5f9cda1SChristian Sigg       llvmVoidType,
160a5f9cda1SChristian Sigg       {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
161a5f9cda1SChristian Sigg   FunctionCallBuilder hostRegisterCallBuilder = {
162a5f9cda1SChristian Sigg       "mgpuMemHostRegisterMemRef",
163a5f9cda1SChristian Sigg       llvmVoidType,
164a5f9cda1SChristian Sigg       {llvmIntPtrType /* intptr_t rank */,
165a5f9cda1SChristian Sigg        llvmPointerType /* void *memrefDesc */,
166a5f9cda1SChristian Sigg        llvmIntPtrType /* intptr_t elementSizeBytes */}};
167a5f9cda1SChristian Sigg   FunctionCallBuilder allocCallBuilder = {
168a5f9cda1SChristian Sigg       "mgpuMemAlloc",
169a5f9cda1SChristian Sigg       llvmPointerType /* void * */,
170a5f9cda1SChristian Sigg       {llvmIntPtrType /* intptr_t sizeBytes */,
171a5f9cda1SChristian Sigg        llvmPointerType /* void *stream */}};
172a5f9cda1SChristian Sigg   FunctionCallBuilder deallocCallBuilder = {
173a5f9cda1SChristian Sigg       "mgpuMemFree",
174a5f9cda1SChristian Sigg       llvmVoidType,
175a5f9cda1SChristian Sigg       {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
176a5f9cda1SChristian Sigg   FunctionCallBuilder memcpyCallBuilder = {
177a5f9cda1SChristian Sigg       "mgpuMemcpy",
178a5f9cda1SChristian Sigg       llvmVoidType,
179a5f9cda1SChristian Sigg       {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
180a5f9cda1SChristian Sigg        llvmIntPtrType /* intptr_t sizeBytes */,
181a5f9cda1SChristian Sigg        llvmPointerType /* void *stream */}};
182361458b1SLoren Maggiore   FunctionCallBuilder memsetCallBuilder = {
183361458b1SLoren Maggiore       "mgpuMemset32",
184361458b1SLoren Maggiore       llvmVoidType,
185361458b1SLoren Maggiore       {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
186361458b1SLoren Maggiore        llvmIntPtrType /* intptr_t sizeBytes */,
187361458b1SLoren Maggiore        llvmPointerType /* void *stream */}};
18884718d37SKrzysztof Drewniak   FunctionCallBuilder setDefaultDeviceCallBuilder = {
18984718d37SKrzysztof Drewniak       "mgpuSetDefaultDevice",
19084718d37SKrzysztof Drewniak       llvmVoidType,
19184718d37SKrzysztof Drewniak       {llvmInt32Type /* uint32_t devIndex */}};
192a5f9cda1SChristian Sigg };
193a5f9cda1SChristian Sigg 
194a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.host_register operations into a GPU runtime
195a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP).
196a5f9cda1SChristian Sigg class ConvertHostRegisterOpToGpuRuntimeCallPattern
197a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
198a5f9cda1SChristian Sigg public:
ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)199a5f9cda1SChristian Sigg   ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
200a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}
201a5f9cda1SChristian Sigg 
202a5f9cda1SChristian Sigg private:
203a5f9cda1SChristian Sigg   LogicalResult
204ef976337SRiver Riddle   matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
205a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
206a5f9cda1SChristian Sigg };
207a5f9cda1SChristian Sigg 
208a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
209a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP).
210a5f9cda1SChristian Sigg class ConvertAllocOpToGpuRuntimeCallPattern
211a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
212a5f9cda1SChristian Sigg public:
ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)213a5f9cda1SChristian Sigg   ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
214a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}
215a5f9cda1SChristian Sigg 
216a5f9cda1SChristian Sigg private:
217a5f9cda1SChristian Sigg   LogicalResult
218ef976337SRiver Riddle   matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
219a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
220a5f9cda1SChristian Sigg };
221a5f9cda1SChristian Sigg 
222a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
223a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP).
224a5f9cda1SChristian Sigg class ConvertDeallocOpToGpuRuntimeCallPattern
225a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
226a5f9cda1SChristian Sigg public:
ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)227a5f9cda1SChristian Sigg   ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
228a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}
229a5f9cda1SChristian Sigg 
230a5f9cda1SChristian Sigg private:
231a5f9cda1SChristian Sigg   LogicalResult
232ef976337SRiver Riddle   matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
233a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
234a5f9cda1SChristian Sigg };
235a5f9cda1SChristian Sigg 
236a5f9cda1SChristian Sigg class ConvertAsyncYieldToGpuRuntimeCallPattern
237a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
238a5f9cda1SChristian Sigg public:
ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)239a5f9cda1SChristian Sigg   ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
240a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}
241a5f9cda1SChristian Sigg 
242a5f9cda1SChristian Sigg private:
243a5f9cda1SChristian Sigg   LogicalResult
244ef976337SRiver Riddle   matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
245a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
246a5f9cda1SChristian Sigg };
247a5f9cda1SChristian Sigg 
248a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.wait operations into a GPU runtime
249a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP).
250a5f9cda1SChristian Sigg class ConvertWaitOpToGpuRuntimeCallPattern
251a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
252a5f9cda1SChristian Sigg public:
ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)253a5f9cda1SChristian Sigg   ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
254a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
255a5f9cda1SChristian Sigg 
256a5f9cda1SChristian Sigg private:
257a5f9cda1SChristian Sigg   LogicalResult
258ef976337SRiver Riddle   matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
259a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
260a5f9cda1SChristian Sigg };
261a5f9cda1SChristian Sigg 
262a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
263a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP).
264a5f9cda1SChristian Sigg class ConvertWaitAsyncOpToGpuRuntimeCallPattern
265a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
266a5f9cda1SChristian Sigg public:
ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)267a5f9cda1SChristian Sigg   ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
268a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}
269a5f9cda1SChristian Sigg 
270a5f9cda1SChristian Sigg private:
271a5f9cda1SChristian Sigg   LogicalResult
272ef976337SRiver Riddle   matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
273a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
274a5f9cda1SChristian Sigg };
275a5f9cda1SChristian Sigg 
276a5f9cda1SChristian Sigg /// A rewrite patter to convert gpu.launch_func operations into a sequence of
277a5f9cda1SChristian Sigg /// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
278a5f9cda1SChristian Sigg ///
279a5f9cda1SChristian Sigg /// In essence, a gpu.launch_func operations gets compiled into the following
280a5f9cda1SChristian Sigg /// sequence of runtime calls:
281a5f9cda1SChristian Sigg ///
282a5f9cda1SChristian Sigg /// * moduleLoad        -- loads the module given the cubin / hsaco data
283a5f9cda1SChristian Sigg /// * moduleGetFunction -- gets a handle to the actual kernel function
284a5f9cda1SChristian Sigg /// * getStreamHelper   -- initializes a new compute stream on GPU
285a5f9cda1SChristian Sigg /// * launchKernel      -- launches the kernel on a stream
286a5f9cda1SChristian Sigg /// * streamSynchronize -- waits for operations on the stream to finish
287a5f9cda1SChristian Sigg ///
288a5f9cda1SChristian Sigg /// Intermediate data structures are allocated on the stack.
289a5f9cda1SChristian Sigg class ConvertLaunchFuncOpToGpuRuntimeCallPattern
290a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
291a5f9cda1SChristian Sigg public:
ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter,StringRef gpuBinaryAnnotation)292a5f9cda1SChristian Sigg   ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
293a5f9cda1SChristian Sigg                                              StringRef gpuBinaryAnnotation)
294a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
295a5f9cda1SChristian Sigg         gpuBinaryAnnotation(gpuBinaryAnnotation) {}
296a5f9cda1SChristian Sigg 
297a5f9cda1SChristian Sigg private:
298ef976337SRiver Riddle   Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
299ef976337SRiver Riddle                             OpBuilder &builder) const;
300a5f9cda1SChristian Sigg   Value generateKernelNameConstant(StringRef moduleName, StringRef name,
301a5f9cda1SChristian Sigg                                    Location loc, OpBuilder &builder) const;
302a5f9cda1SChristian Sigg 
303a5f9cda1SChristian Sigg   LogicalResult
304ef976337SRiver Riddle   matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
305a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
306a5f9cda1SChristian Sigg 
307a5f9cda1SChristian Sigg   llvm::SmallString<32> gpuBinaryAnnotation;
308a5f9cda1SChristian Sigg };
309a5f9cda1SChristian Sigg 
310a5f9cda1SChristian Sigg class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
311a5f9cda1SChristian Sigg   using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;
312a5f9cda1SChristian Sigg 
matchAndRewrite(gpu::GPUModuleOp op,PatternRewriter & rewriter) const313a5f9cda1SChristian Sigg   LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
314a5f9cda1SChristian Sigg                                 PatternRewriter &rewriter) const override {
315a5f9cda1SChristian Sigg     // GPU kernel modules are no longer necessary since we have a global
316a5f9cda1SChristian Sigg     // constant with the CUBIN, or HSACO data.
317a5f9cda1SChristian Sigg     rewriter.eraseOp(op);
318a5f9cda1SChristian Sigg     return success();
319a5f9cda1SChristian Sigg   }
320a5f9cda1SChristian Sigg };
321a5f9cda1SChristian Sigg 
322a5f9cda1SChristian Sigg /// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
323a5f9cda1SChristian Sigg /// call. Currently it supports CUDA and ROCm (HIP).
324a5f9cda1SChristian Sigg class ConvertMemcpyOpToGpuRuntimeCallPattern
325a5f9cda1SChristian Sigg     : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
326a5f9cda1SChristian Sigg public:
ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)327a5f9cda1SChristian Sigg   ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
328a5f9cda1SChristian Sigg       : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}
329a5f9cda1SChristian Sigg 
330a5f9cda1SChristian Sigg private:
331a5f9cda1SChristian Sigg   LogicalResult
332ef976337SRiver Riddle   matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
333a5f9cda1SChristian Sigg                   ConversionPatternRewriter &rewriter) const override;
334a5f9cda1SChristian Sigg };
335361458b1SLoren Maggiore 
336361458b1SLoren Maggiore /// A rewrite pattern to convert gpu.memset operations into a GPU runtime
337361458b1SLoren Maggiore /// call. Currently it supports CUDA and ROCm (HIP).
338361458b1SLoren Maggiore class ConvertMemsetOpToGpuRuntimeCallPattern
339361458b1SLoren Maggiore     : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
340361458b1SLoren Maggiore public:
ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)341361458b1SLoren Maggiore   ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
342361458b1SLoren Maggiore       : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}
343361458b1SLoren Maggiore 
344361458b1SLoren Maggiore private:
345361458b1SLoren Maggiore   LogicalResult
346ef976337SRiver Riddle   matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
347361458b1SLoren Maggiore                   ConversionPatternRewriter &rewriter) const override;
348361458b1SLoren Maggiore };
34984718d37SKrzysztof Drewniak 
35084718d37SKrzysztof Drewniak /// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
35184718d37SKrzysztof Drewniak /// Currently supports CUDA and ROCm (HIP)
35284718d37SKrzysztof Drewniak class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
35384718d37SKrzysztof Drewniak     : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
35484718d37SKrzysztof Drewniak public:
ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(LLVMTypeConverter & typeConverter)35584718d37SKrzysztof Drewniak   ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
35684718d37SKrzysztof Drewniak       LLVMTypeConverter &typeConverter)
35784718d37SKrzysztof Drewniak       : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
35884718d37SKrzysztof Drewniak             typeConverter) {}
35984718d37SKrzysztof Drewniak 
36084718d37SKrzysztof Drewniak   LogicalResult
36184718d37SKrzysztof Drewniak   matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
36284718d37SKrzysztof Drewniak                   ConversionPatternRewriter &rewriter) const override;
36384718d37SKrzysztof Drewniak };
364a5f9cda1SChristian Sigg } // namespace
365a5f9cda1SChristian Sigg 
runOnOperation()366a5f9cda1SChristian Sigg void GpuToLLVMConversionPass::runOnOperation() {
367a5f9cda1SChristian Sigg   LLVMTypeConverter converter(&getContext());
368dc4e913bSChris Lattner   RewritePatternSet patterns(&getContext());
369a5f9cda1SChristian Sigg   LLVMConversionTarget target(getContext());
370a5f9cda1SChristian Sigg 
371abe501f2SChristian Sigg   target.addIllegalDialect<gpu::GPUDialect>();
372abe501f2SChristian Sigg 
373a54f4eaeSMogball   mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
374330838ebSRiver Riddle   mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
375a5f9cda1SChristian Sigg   populateVectorToLLVMConversionPatterns(converter, patterns);
37675e5f0aaSAlex Zinenko   populateMemRefToLLVMConversionPatterns(converter, patterns);
3775a7b9194SRiver Riddle   populateFuncToLLVMConversionPatterns(converter, patterns);
3783a506b31SChris Lattner   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
3793a506b31SChris Lattner                                                     target);
3807d855605SButygin   populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);
381a5f9cda1SChristian Sigg 
382a5f9cda1SChristian Sigg   if (failed(
383a5f9cda1SChristian Sigg           applyPartialConversion(getOperation(), target, std::move(patterns))))
384a5f9cda1SChristian Sigg     signalPassFailure();
385a5f9cda1SChristian Sigg }
386a5f9cda1SChristian Sigg 
create(Location loc,OpBuilder & builder,ArrayRef<Value> arguments) const387a5f9cda1SChristian Sigg LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
388a5f9cda1SChristian Sigg                                          ArrayRef<Value> arguments) const {
389a5f9cda1SChristian Sigg   auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
390a5f9cda1SChristian Sigg   auto function = [&] {
391a5f9cda1SChristian Sigg     if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
392a5f9cda1SChristian Sigg       return function;
393973ddb7dSMehdi Amini     return OpBuilder::atBlockEnd(module.getBody())
394a5f9cda1SChristian Sigg         .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
395a5f9cda1SChristian Sigg   }();
396c1f719d1SAlex Zinenko   return builder.create<LLVM::CallOp>(loc, function, arguments);
397a5f9cda1SChristian Sigg }
398a5f9cda1SChristian Sigg 
399a5f9cda1SChristian Sigg // Returns whether all operands are of LLVM type.
areAllLLVMTypes(Operation * op,ValueRange operands,ConversionPatternRewriter & rewriter)400a5f9cda1SChristian Sigg static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
401a5f9cda1SChristian Sigg                                      ConversionPatternRewriter &rewriter) {
402a5f9cda1SChristian Sigg   if (!llvm::all_of(operands, [](Value value) {
403a5f9cda1SChristian Sigg         return LLVM::isCompatibleType(value.getType());
404a5f9cda1SChristian Sigg       }))
405a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(
406a5f9cda1SChristian Sigg         op, "Cannot convert if operands aren't of LLVM type.");
407a5f9cda1SChristian Sigg   return success();
408a5f9cda1SChristian Sigg }
409a5f9cda1SChristian Sigg 
410a5f9cda1SChristian Sigg static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter & rewriter,gpu::AsyncOpInterface op)411a5f9cda1SChristian Sigg isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
412a5f9cda1SChristian Sigg                          gpu::AsyncOpInterface op) {
413a5f9cda1SChristian Sigg   if (op.getAsyncDependencies().size() != 1)
414a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(
415a5f9cda1SChristian Sigg         op, "Can only convert with exactly one async dependency.");
416a5f9cda1SChristian Sigg 
417a5f9cda1SChristian Sigg   if (!op.getAsyncToken())
418a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(op, "Can convert only async version.");
419a5f9cda1SChristian Sigg 
420a5f9cda1SChristian Sigg   return success();
421a5f9cda1SChristian Sigg }
422a5f9cda1SChristian Sigg 
matchAndRewrite(gpu::HostRegisterOp hostRegisterOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const423a5f9cda1SChristian Sigg LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
424ef976337SRiver Riddle     gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
425a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
426a5f9cda1SChristian Sigg   auto *op = hostRegisterOp.getOperation();
427ef976337SRiver Riddle   if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
428a5f9cda1SChristian Sigg     return failure();
429a5f9cda1SChristian Sigg 
430a5f9cda1SChristian Sigg   Location loc = op->getLoc();
431a5f9cda1SChristian Sigg 
432a5f9cda1SChristian Sigg   auto memRefType = hostRegisterOp.value().getType();
433a5f9cda1SChristian Sigg   auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
434a5f9cda1SChristian Sigg   auto elementSize = getSizeInBytes(loc, elementType, rewriter);
435a5f9cda1SChristian Sigg 
436ef976337SRiver Riddle   auto arguments = getTypeConverter()->promoteOperands(
437ef976337SRiver Riddle       loc, op->getOperands(), adaptor.getOperands(), rewriter);
438a5f9cda1SChristian Sigg   arguments.push_back(elementSize);
439a5f9cda1SChristian Sigg   hostRegisterCallBuilder.create(loc, rewriter, arguments);
440a5f9cda1SChristian Sigg 
441a5f9cda1SChristian Sigg   rewriter.eraseOp(op);
442a5f9cda1SChristian Sigg   return success();
443a5f9cda1SChristian Sigg }
444a5f9cda1SChristian Sigg 
matchAndRewrite(gpu::AllocOp allocOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const445a5f9cda1SChristian Sigg LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
446ef976337SRiver Riddle     gpu::AllocOp allocOp, OpAdaptor adaptor,
447a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
448a5f9cda1SChristian Sigg   MemRefType memRefType = allocOp.getType();
449a5f9cda1SChristian Sigg 
450ef976337SRiver Riddle   if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
451a5f9cda1SChristian Sigg       !isConvertibleAndHasIdentityMaps(memRefType) ||
452a5f9cda1SChristian Sigg       failed(isAsyncWithOneDependency(rewriter, allocOp)))
453a5f9cda1SChristian Sigg     return failure();
454a5f9cda1SChristian Sigg 
455a5f9cda1SChristian Sigg   auto loc = allocOp.getLoc();
456a5f9cda1SChristian Sigg 
457a5f9cda1SChristian Sigg   // Get shape of the memref as values: static sizes are constant
458a5f9cda1SChristian Sigg   // values and dynamic sizes are passed to 'alloc' as operands.
459a5f9cda1SChristian Sigg   SmallVector<Value, 4> shape;
460a5f9cda1SChristian Sigg   SmallVector<Value, 4> strides;
461a5f9cda1SChristian Sigg   Value sizeBytes;
462a5f9cda1SChristian Sigg   getMemRefDescriptorSizes(loc, memRefType, adaptor.dynamicSizes(), rewriter,
463a5f9cda1SChristian Sigg                            shape, strides, sizeBytes);
464a5f9cda1SChristian Sigg 
465a5f9cda1SChristian Sigg   // Allocate the underlying buffer and store a pointer to it in the MemRef
466a5f9cda1SChristian Sigg   // descriptor.
467a5f9cda1SChristian Sigg   Type elementPtrType = this->getElementPtrType(memRefType);
468a5f9cda1SChristian Sigg   auto stream = adaptor.asyncDependencies().front();
469a5f9cda1SChristian Sigg   Value allocatedPtr =
470a5f9cda1SChristian Sigg       allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
471a5f9cda1SChristian Sigg   allocatedPtr =
472a5f9cda1SChristian Sigg       rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);
473a5f9cda1SChristian Sigg 
474a5f9cda1SChristian Sigg   // No alignment.
475a5f9cda1SChristian Sigg   Value alignedPtr = allocatedPtr;
476a5f9cda1SChristian Sigg 
477a5f9cda1SChristian Sigg   // Create the MemRef descriptor.
478a5f9cda1SChristian Sigg   auto memRefDescriptor = this->createMemRefDescriptor(
479a5f9cda1SChristian Sigg       loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);
480a5f9cda1SChristian Sigg 
481a5f9cda1SChristian Sigg   rewriter.replaceOp(allocOp, {memRefDescriptor, stream});
482a5f9cda1SChristian Sigg 
483a5f9cda1SChristian Sigg   return success();
484a5f9cda1SChristian Sigg }
485a5f9cda1SChristian Sigg 
matchAndRewrite(gpu::DeallocOp deallocOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const486a5f9cda1SChristian Sigg LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
487ef976337SRiver Riddle     gpu::DeallocOp deallocOp, OpAdaptor adaptor,
488a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
489ef976337SRiver Riddle   if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
490a5f9cda1SChristian Sigg       failed(isAsyncWithOneDependency(rewriter, deallocOp)))
491a5f9cda1SChristian Sigg     return failure();
492a5f9cda1SChristian Sigg 
493a5f9cda1SChristian Sigg   Location loc = deallocOp.getLoc();
494a5f9cda1SChristian Sigg 
495a5f9cda1SChristian Sigg   Value pointer =
496a5f9cda1SChristian Sigg       MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
497a5f9cda1SChristian Sigg   auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
498a5f9cda1SChristian Sigg   Value stream = adaptor.asyncDependencies().front();
499a5f9cda1SChristian Sigg   deallocCallBuilder.create(loc, rewriter, {casted, stream});
500a5f9cda1SChristian Sigg 
501a5f9cda1SChristian Sigg   rewriter.replaceOp(deallocOp, {stream});
502a5f9cda1SChristian Sigg   return success();
503a5f9cda1SChristian Sigg }
504a5f9cda1SChristian Sigg 
isGpuAsyncTokenType(Value value)505a5f9cda1SChristian Sigg static bool isGpuAsyncTokenType(Value value) {
506a5f9cda1SChristian Sigg   return value.getType().isa<gpu::AsyncTokenType>();
507a5f9cda1SChristian Sigg }
508a5f9cda1SChristian Sigg 
509a5f9cda1SChristian Sigg // Converts !gpu.async.token operands of `async.yield` to runtime calls. The
510a5f9cda1SChristian Sigg // !gpu.async.token are lowered to stream within the async.execute region, but
511a5f9cda1SChristian Sigg // are passed as events between them. For each !gpu.async.token operand, we
512a5f9cda1SChristian Sigg // create an event and record it on the stream.
matchAndRewrite(async::YieldOp yieldOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const513a5f9cda1SChristian Sigg LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
514ef976337SRiver Riddle     async::YieldOp yieldOp, OpAdaptor adaptor,
515a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
516a5f9cda1SChristian Sigg   if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType))
517a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");
518a5f9cda1SChristian Sigg 
519a5f9cda1SChristian Sigg   Location loc = yieldOp.getLoc();
520ef976337SRiver Riddle   SmallVector<Value, 4> newOperands(adaptor.getOperands());
521a5f9cda1SChristian Sigg   llvm::SmallDenseSet<Value> streams;
522a5f9cda1SChristian Sigg   for (auto &operand : yieldOp->getOpOperands()) {
523a5f9cda1SChristian Sigg     if (!isGpuAsyncTokenType(operand.get()))
524a5f9cda1SChristian Sigg       continue;
525a5f9cda1SChristian Sigg     auto idx = operand.getOperandNumber();
526ef976337SRiver Riddle     auto stream = adaptor.getOperands()[idx];
527a5f9cda1SChristian Sigg     auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
528a5f9cda1SChristian Sigg     eventRecordCallBuilder.create(loc, rewriter, {event, stream});
529a5f9cda1SChristian Sigg     newOperands[idx] = event;
530a5f9cda1SChristian Sigg     streams.insert(stream);
531a5f9cda1SChristian Sigg   }
532a5f9cda1SChristian Sigg   for (auto stream : streams)
533a5f9cda1SChristian Sigg     streamDestroyCallBuilder.create(loc, rewriter, {stream});
534a5f9cda1SChristian Sigg 
535a5f9cda1SChristian Sigg   rewriter.updateRootInPlace(yieldOp,
536a5f9cda1SChristian Sigg                              [&] { yieldOp->setOperands(newOperands); });
537a5f9cda1SChristian Sigg   return success();
538a5f9cda1SChristian Sigg }
539a5f9cda1SChristian Sigg 
540a5f9cda1SChristian Sigg // Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
isDefinedByCallTo(Value value,StringRef functionName)541a5f9cda1SChristian Sigg static bool isDefinedByCallTo(Value value, StringRef functionName) {
542a5f9cda1SChristian Sigg   assert(value.getType().isa<LLVM::LLVMPointerType>());
543a5f9cda1SChristian Sigg   if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
544cfb72fd3SJacques Pienaar     return defOp.getCallee()->equals(functionName);
545a5f9cda1SChristian Sigg   return false;
546a5f9cda1SChristian Sigg }
547a5f9cda1SChristian Sigg 
548a5f9cda1SChristian Sigg // Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
549a5f9cda1SChristian Sigg // with the stream/event operands. The operands are destroyed. That is, it
550a5f9cda1SChristian Sigg // assumes that it is not used afterwards or elsewhere. Otherwise we will get a
551a5f9cda1SChristian Sigg // runtime error. Eventually, we should guarantee this property.
matchAndRewrite(gpu::WaitOp waitOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const552a5f9cda1SChristian Sigg LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
553ef976337SRiver Riddle     gpu::WaitOp waitOp, OpAdaptor adaptor,
554a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
555a5f9cda1SChristian Sigg   if (waitOp.asyncToken())
556a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");
557a5f9cda1SChristian Sigg 
558a5f9cda1SChristian Sigg   Location loc = waitOp.getLoc();
559a5f9cda1SChristian Sigg 
560ef976337SRiver Riddle   for (auto operand : adaptor.getOperands()) {
561a5f9cda1SChristian Sigg     if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
562a5f9cda1SChristian Sigg       // The converted operand's definition created a stream.
563a5f9cda1SChristian Sigg       streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
564a5f9cda1SChristian Sigg       streamDestroyCallBuilder.create(loc, rewriter, {operand});
565a5f9cda1SChristian Sigg     } else {
566a5f9cda1SChristian Sigg       // Otherwise the converted operand is an event. This assumes that we use
567a5f9cda1SChristian Sigg       // events in control flow code as well.
568a5f9cda1SChristian Sigg       eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
569a5f9cda1SChristian Sigg       eventDestroyCallBuilder.create(loc, rewriter, {operand});
570a5f9cda1SChristian Sigg     }
571a5f9cda1SChristian Sigg   }
572a5f9cda1SChristian Sigg 
573a5f9cda1SChristian Sigg   rewriter.eraseOp(waitOp);
574a5f9cda1SChristian Sigg   return success();
575a5f9cda1SChristian Sigg }
576a5f9cda1SChristian Sigg 
577a5f9cda1SChristian Sigg // Converts `gpu.wait async` to runtime calls. The converted op creates a new
578a5f9cda1SChristian Sigg // stream that is synchronized with stream/event operands. The operands are
579a5f9cda1SChristian Sigg // destroyed. That is, it assumes that it is not used afterwards or elsewhere.
580a5f9cda1SChristian Sigg // Otherwise we will get a runtime error. Eventually, we should guarantee this
581a5f9cda1SChristian Sigg // property.
matchAndRewrite(gpu::WaitOp waitOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const582a5f9cda1SChristian Sigg LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
583ef976337SRiver Riddle     gpu::WaitOp waitOp, OpAdaptor adaptor,
584a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
585a5f9cda1SChristian Sigg   if (!waitOp.asyncToken())
586a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");
587a5f9cda1SChristian Sigg 
588a5f9cda1SChristian Sigg   Location loc = waitOp.getLoc();
589a5f9cda1SChristian Sigg 
590a5f9cda1SChristian Sigg   auto insertionPoint = rewriter.saveInsertionPoint();
591a5f9cda1SChristian Sigg   SmallVector<Value, 1> events;
592ef976337SRiver Riddle   for (auto pair :
593ef976337SRiver Riddle        llvm::zip(waitOp.asyncDependencies(), adaptor.getOperands())) {
594a5f9cda1SChristian Sigg     auto operand = std::get<1>(pair);
595a5f9cda1SChristian Sigg     if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
596a5f9cda1SChristian Sigg       // The converted operand's definition created a stream. Insert an event
597a5f9cda1SChristian Sigg       // into the stream just after the last use of the original token operand.
598a5f9cda1SChristian Sigg       auto *defOp = std::get<0>(pair).getDefiningOp();
599a5f9cda1SChristian Sigg       rewriter.setInsertionPointAfter(defOp);
600a5f9cda1SChristian Sigg       auto event =
601a5f9cda1SChristian Sigg           eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
602a5f9cda1SChristian Sigg       eventRecordCallBuilder.create(loc, rewriter, {event, operand});
603a5f9cda1SChristian Sigg       events.push_back(event);
604a5f9cda1SChristian Sigg     } else {
605a5f9cda1SChristian Sigg       // Otherwise the converted operand is an event. This assumes that we use
606a5f9cda1SChristian Sigg       // events in control flow code as well.
607a5f9cda1SChristian Sigg       events.push_back(operand);
608a5f9cda1SChristian Sigg     }
609a5f9cda1SChristian Sigg   }
610a5f9cda1SChristian Sigg   rewriter.restoreInsertionPoint(insertionPoint);
611a5f9cda1SChristian Sigg   auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
612a5f9cda1SChristian Sigg   for (auto event : events)
613a5f9cda1SChristian Sigg     streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
614a5f9cda1SChristian Sigg   for (auto event : events)
615a5f9cda1SChristian Sigg     eventDestroyCallBuilder.create(loc, rewriter, {event});
616a5f9cda1SChristian Sigg   rewriter.replaceOp(waitOp, {stream});
617a5f9cda1SChristian Sigg 
618a5f9cda1SChristian Sigg   return success();
619a5f9cda1SChristian Sigg }
620a5f9cda1SChristian Sigg 
621a5f9cda1SChristian Sigg // Creates a struct containing all kernel parameters on the stack and returns
622a5f9cda1SChristian Sigg // an array of type-erased pointers to the fields of the struct. The array can
623a5f9cda1SChristian Sigg // then be passed to the CUDA / ROCm (HIP) kernel launch calls.
624a5f9cda1SChristian Sigg // The generated code is essentially as follows:
625a5f9cda1SChristian Sigg //
626a5f9cda1SChristian Sigg // %struct = alloca(sizeof(struct { Parameters... }))
627a5f9cda1SChristian Sigg // %array = alloca(NumParameters * sizeof(void *))
628a5f9cda1SChristian Sigg // for (i : [0, NumParameters))
629a5f9cda1SChristian Sigg //   %fieldPtr = llvm.getelementptr %struct[0, i]
630a5f9cda1SChristian Sigg //   llvm.store parameters[i], %fieldPtr
631a5f9cda1SChristian Sigg //   %elementPtr = llvm.getelementptr %array[i]
632a5f9cda1SChristian Sigg //   llvm.store %fieldPtr, %elementPtr
633a5f9cda1SChristian Sigg // return %array
generateParamsArray(gpu::LaunchFuncOp launchOp,OpAdaptor adaptor,OpBuilder & builder) const634a5f9cda1SChristian Sigg Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
635ef976337SRiver Riddle     gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
636a5f9cda1SChristian Sigg   auto loc = launchOp.getLoc();
637a5f9cda1SChristian Sigg   auto numKernelOperands = launchOp.getNumKernelOperands();
638a5f9cda1SChristian Sigg   auto arguments = getTypeConverter()->promoteOperands(
639a5f9cda1SChristian Sigg       loc, launchOp.getOperands().take_back(numKernelOperands),
640ef976337SRiver Riddle       adaptor.getOperands().take_back(numKernelOperands), builder);
641a5f9cda1SChristian Sigg   auto numArguments = arguments.size();
642a5f9cda1SChristian Sigg   SmallVector<Type, 4> argumentTypes;
643a5f9cda1SChristian Sigg   argumentTypes.reserve(numArguments);
644a5f9cda1SChristian Sigg   for (auto argument : arguments)
645a5f9cda1SChristian Sigg     argumentTypes.push_back(argument.getType());
646a5f9cda1SChristian Sigg   auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
647a5f9cda1SChristian Sigg                                                            argumentTypes);
648a5f9cda1SChristian Sigg   auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
649a5f9cda1SChristian Sigg                                               builder.getI32IntegerAttr(1));
650a5f9cda1SChristian Sigg   auto structPtr = builder.create<LLVM::AllocaOp>(
651a5f9cda1SChristian Sigg       loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
652a5f9cda1SChristian Sigg   auto arraySize = builder.create<LLVM::ConstantOp>(
653a5f9cda1SChristian Sigg       loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
654a5f9cda1SChristian Sigg   auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
655a5f9cda1SChristian Sigg                                                  arraySize, /*alignment=*/0);
656a5f9cda1SChristian Sigg   auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
657a5f9cda1SChristian Sigg                                                builder.getI32IntegerAttr(0));
658e4853be2SMehdi Amini   for (const auto &en : llvm::enumerate(arguments)) {
659a5f9cda1SChristian Sigg     auto index = builder.create<LLVM::ConstantOp>(
660a5f9cda1SChristian Sigg         loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
661a5f9cda1SChristian Sigg     auto fieldPtr = builder.create<LLVM::GEPOp>(
662a5f9cda1SChristian Sigg         loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
663a5f9cda1SChristian Sigg         ArrayRef<Value>{zero, index.getResult()});
664a5f9cda1SChristian Sigg     builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
665a5f9cda1SChristian Sigg     auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
666a5f9cda1SChristian Sigg                                                   arrayPtr, index.getResult());
667a5f9cda1SChristian Sigg     auto casted =
668a5f9cda1SChristian Sigg         builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
669a5f9cda1SChristian Sigg     builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
670a5f9cda1SChristian Sigg   }
671a5f9cda1SChristian Sigg   return arrayPtr;
672a5f9cda1SChristian Sigg }
673a5f9cda1SChristian Sigg 
674a5f9cda1SChristian Sigg // Generates an LLVM IR dialect global that contains the name of the given
675a5f9cda1SChristian Sigg // kernel function as a C string, and returns a pointer to its beginning.
676a5f9cda1SChristian Sigg // The code is essentially:
677a5f9cda1SChristian Sigg //
678a5f9cda1SChristian Sigg // llvm.global constant @kernel_name("function_name\00")
679a5f9cda1SChristian Sigg // func(...) {
680a5f9cda1SChristian Sigg //   %0 = llvm.addressof @kernel_name
681a5f9cda1SChristian Sigg //   %1 = llvm.constant (0 : index)
682a5f9cda1SChristian Sigg //   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
683a5f9cda1SChristian Sigg // }
generateKernelNameConstant(StringRef moduleName,StringRef name,Location loc,OpBuilder & builder) const684a5f9cda1SChristian Sigg Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
685a5f9cda1SChristian Sigg     StringRef moduleName, StringRef name, Location loc,
686a5f9cda1SChristian Sigg     OpBuilder &builder) const {
687a5f9cda1SChristian Sigg   // Make sure the trailing zero is included in the constant.
688a5f9cda1SChristian Sigg   std::vector<char> kernelName(name.begin(), name.end());
689a5f9cda1SChristian Sigg   kernelName.push_back('\0');
690a5f9cda1SChristian Sigg 
691a5f9cda1SChristian Sigg   std::string globalName =
692a5f9cda1SChristian Sigg       std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
693a5f9cda1SChristian Sigg   return LLVM::createGlobalString(
694a5f9cda1SChristian Sigg       loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
695a5f9cda1SChristian Sigg       LLVM::Linkage::Internal);
696a5f9cda1SChristian Sigg }
697a5f9cda1SChristian Sigg 
698a5f9cda1SChristian Sigg // Emits LLVM IR to launch a kernel function. Expects the module that contains
699a5f9cda1SChristian Sigg // the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
700a5f9cda1SChristian Sigg // hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
701a5f9cda1SChristian Sigg //
702a5f9cda1SChristian Sigg // %0 = call %binarygetter
703a5f9cda1SChristian Sigg // %1 = call %moduleLoad(%0)
704a5f9cda1SChristian Sigg // %2 = <see generateKernelNameConstant>
705a5f9cda1SChristian Sigg // %3 = call %moduleGetFunction(%1, %2)
706a5f9cda1SChristian Sigg // %4 = call %streamCreate()
707a5f9cda1SChristian Sigg // %5 = <see generateParamsArray>
708a5f9cda1SChristian Sigg // call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
709a5f9cda1SChristian Sigg // call %streamSynchronize(%4)
710a5f9cda1SChristian Sigg // call %streamDestroy(%4)
711a5f9cda1SChristian Sigg // call %moduleUnload(%1)
712a5f9cda1SChristian Sigg //
713a5f9cda1SChristian Sigg // If the op is async, the stream corresponds to the (single) async dependency
714a5f9cda1SChristian Sigg // as well as the async token the op produces.
matchAndRewrite(gpu::LaunchFuncOp launchOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const715a5f9cda1SChristian Sigg LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
716ef976337SRiver Riddle     gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
717a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
718ef976337SRiver Riddle   if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
719a5f9cda1SChristian Sigg     return failure();
720a5f9cda1SChristian Sigg 
721a5f9cda1SChristian Sigg   if (launchOp.asyncDependencies().size() > 1)
722a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(
723a5f9cda1SChristian Sigg         launchOp, "Cannot convert with more than one async dependency.");
724a5f9cda1SChristian Sigg 
725a5f9cda1SChristian Sigg   // Fail when the synchronous version of the op has async dependencies. The
726a5f9cda1SChristian Sigg   // lowering destroys the stream, and we do not want to check that there is no
727a5f9cda1SChristian Sigg   // use of the stream after this op.
728a5f9cda1SChristian Sigg   if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
729a5f9cda1SChristian Sigg     return rewriter.notifyMatchFailure(
730a5f9cda1SChristian Sigg         launchOp, "Cannot convert non-async op with async dependencies.");
731a5f9cda1SChristian Sigg 
732a5f9cda1SChristian Sigg   Location loc = launchOp.getLoc();
733a5f9cda1SChristian Sigg 
734a5f9cda1SChristian Sigg   // Create an LLVM global with CUBIN extracted from the kernel annotation and
735a5f9cda1SChristian Sigg   // obtain a pointer to the first byte in it.
736a5f9cda1SChristian Sigg   auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
737a5f9cda1SChristian Sigg       launchOp, launchOp.getKernelModuleName());
738a5f9cda1SChristian Sigg   assert(kernelModule && "expected a kernel module");
739a5f9cda1SChristian Sigg 
740a5f9cda1SChristian Sigg   auto binaryAttr =
741a5f9cda1SChristian Sigg       kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
742a5f9cda1SChristian Sigg   if (!binaryAttr) {
743a5f9cda1SChristian Sigg     kernelModule.emitOpError()
744a5f9cda1SChristian Sigg         << "missing " << gpuBinaryAnnotation << " attribute";
745a5f9cda1SChristian Sigg     return failure();
746a5f9cda1SChristian Sigg   }
747a5f9cda1SChristian Sigg 
748a5f9cda1SChristian Sigg   SmallString<128> nameBuffer(kernelModule.getName());
749a5f9cda1SChristian Sigg   nameBuffer.append(kGpuBinaryStorageSuffix);
750a5f9cda1SChristian Sigg   Value data =
751a5f9cda1SChristian Sigg       LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
752a5f9cda1SChristian Sigg                                binaryAttr.getValue(), LLVM::Linkage::Internal);
753a5f9cda1SChristian Sigg 
754a5f9cda1SChristian Sigg   auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
755a5f9cda1SChristian Sigg   // Get the function from the module. The name corresponds to the name of
756a5f9cda1SChristian Sigg   // the kernel function.
757a5f9cda1SChristian Sigg   auto kernelName = generateKernelNameConstant(
75841d4aa7dSChris Lattner       launchOp.getKernelModuleName().getValue(),
75941d4aa7dSChris Lattner       launchOp.getKernelName().getValue(), loc, rewriter);
760a5f9cda1SChristian Sigg   auto function = moduleGetFunctionCallBuilder.create(
761a5f9cda1SChristian Sigg       loc, rewriter, {module.getResult(0), kernelName});
762a5f9cda1SChristian Sigg   auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
763a5f9cda1SChristian Sigg                                                 rewriter.getI32IntegerAttr(0));
764a5f9cda1SChristian Sigg   Value stream =
765a5f9cda1SChristian Sigg       adaptor.asyncDependencies().empty()
766a5f9cda1SChristian Sigg           ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
767a5f9cda1SChristian Sigg           : adaptor.asyncDependencies().front();
768a5f9cda1SChristian Sigg   // Create array of pointers to kernel arguments.
769ef976337SRiver Riddle   auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
770a5f9cda1SChristian Sigg   auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
77108b63db8SUday Bondhugula   Value dynamicSharedMemorySize = launchOp.dynamicSharedMemorySize()
77208b63db8SUday Bondhugula                                       ? launchOp.dynamicSharedMemorySize()
77308b63db8SUday Bondhugula                                       : zero;
77408b63db8SUday Bondhugula   launchKernelCallBuilder.create(
77508b63db8SUday Bondhugula       loc, rewriter,
77608b63db8SUday Bondhugula       {function.getResult(0), adaptor.gridSizeX(), adaptor.gridSizeY(),
77708b63db8SUday Bondhugula        adaptor.gridSizeZ(), adaptor.blockSizeX(), adaptor.blockSizeY(),
77808b63db8SUday Bondhugula        adaptor.blockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
779a5f9cda1SChristian Sigg        /*extra=*/nullpointer});
780a5f9cda1SChristian Sigg 
781a5f9cda1SChristian Sigg   if (launchOp.asyncToken()) {
782a5f9cda1SChristian Sigg     // Async launch: make dependent ops use the same stream.
783a5f9cda1SChristian Sigg     rewriter.replaceOp(launchOp, {stream});
784a5f9cda1SChristian Sigg   } else {
785a5f9cda1SChristian Sigg     // Synchronize with host and destroy stream. This must be the stream created
786a5f9cda1SChristian Sigg     // above (with no other uses) because we check that the synchronous version
787a5f9cda1SChristian Sigg     // does not have any async dependencies.
788a5f9cda1SChristian Sigg     streamSynchronizeCallBuilder.create(loc, rewriter, stream);
789a5f9cda1SChristian Sigg     streamDestroyCallBuilder.create(loc, rewriter, stream);
790a5f9cda1SChristian Sigg     rewriter.eraseOp(launchOp);
791a5f9cda1SChristian Sigg   }
792a5f9cda1SChristian Sigg   moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));
793a5f9cda1SChristian Sigg 
794a5f9cda1SChristian Sigg   return success();
795a5f9cda1SChristian Sigg }
796a5f9cda1SChristian Sigg 
matchAndRewrite(gpu::MemcpyOp memcpyOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const797a5f9cda1SChristian Sigg LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
798ef976337SRiver Riddle     gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
799a5f9cda1SChristian Sigg     ConversionPatternRewriter &rewriter) const {
800a5f9cda1SChristian Sigg   auto memRefType = memcpyOp.src().getType().cast<MemRefType>();
801a5f9cda1SChristian Sigg 
802ef976337SRiver Riddle   if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
803a5f9cda1SChristian Sigg       !isConvertibleAndHasIdentityMaps(memRefType) ||
804a5f9cda1SChristian Sigg       failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
805a5f9cda1SChristian Sigg     return failure();
806a5f9cda1SChristian Sigg 
807a5f9cda1SChristian Sigg   auto loc = memcpyOp.getLoc();
808a5f9cda1SChristian Sigg 
809a5f9cda1SChristian Sigg   MemRefDescriptor srcDesc(adaptor.src());
810361458b1SLoren Maggiore   Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);
811a5f9cda1SChristian Sigg 
812a5f9cda1SChristian Sigg   Type elementPtrType = getElementPtrType(memRefType);
813a5f9cda1SChristian Sigg   Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
814cafaa350SAlex Zinenko   Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
815cafaa350SAlex Zinenko                                               ArrayRef<Value>{numElements});
816a5f9cda1SChristian Sigg   auto sizeBytes =
817a5f9cda1SChristian Sigg       rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);
818a5f9cda1SChristian Sigg 
819a5f9cda1SChristian Sigg   auto src = rewriter.create<LLVM::BitcastOp>(
820a5f9cda1SChristian Sigg       loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
821a5f9cda1SChristian Sigg   auto dst = rewriter.create<LLVM::BitcastOp>(
822a5f9cda1SChristian Sigg       loc, llvmPointerType,
823a5f9cda1SChristian Sigg       MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));
824a5f9cda1SChristian Sigg 
825a5f9cda1SChristian Sigg   auto stream = adaptor.asyncDependencies().front();
826a5f9cda1SChristian Sigg   memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});
827a5f9cda1SChristian Sigg 
828a5f9cda1SChristian Sigg   rewriter.replaceOp(memcpyOp, {stream});
829a5f9cda1SChristian Sigg 
830a5f9cda1SChristian Sigg   return success();
831a5f9cda1SChristian Sigg }
832a5f9cda1SChristian Sigg 
matchAndRewrite(gpu::MemsetOp memsetOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const833361458b1SLoren Maggiore LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
834ef976337SRiver Riddle     gpu::MemsetOp memsetOp, OpAdaptor adaptor,
835361458b1SLoren Maggiore     ConversionPatternRewriter &rewriter) const {
836361458b1SLoren Maggiore   auto memRefType = memsetOp.dst().getType().cast<MemRefType>();
837361458b1SLoren Maggiore 
838ef976337SRiver Riddle   if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
839361458b1SLoren Maggiore       !isConvertibleAndHasIdentityMaps(memRefType) ||
840361458b1SLoren Maggiore       failed(isAsyncWithOneDependency(rewriter, memsetOp)))
841361458b1SLoren Maggiore     return failure();
842361458b1SLoren Maggiore 
843361458b1SLoren Maggiore   auto loc = memsetOp.getLoc();
844361458b1SLoren Maggiore 
845361458b1SLoren Maggiore   Type valueType = adaptor.value().getType();
846361458b1SLoren Maggiore   if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
847361458b1SLoren Maggiore     return rewriter.notifyMatchFailure(memsetOp,
848361458b1SLoren Maggiore                                        "value must be a 32 bit scalar");
849361458b1SLoren Maggiore   }
850361458b1SLoren Maggiore 
851361458b1SLoren Maggiore   MemRefDescriptor dstDesc(adaptor.dst());
852361458b1SLoren Maggiore   Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);
853361458b1SLoren Maggiore 
854361458b1SLoren Maggiore   auto value =
855361458b1SLoren Maggiore       rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.value());
856361458b1SLoren Maggiore   auto dst = rewriter.create<LLVM::BitcastOp>(
857361458b1SLoren Maggiore       loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc));
858361458b1SLoren Maggiore 
859361458b1SLoren Maggiore   auto stream = adaptor.asyncDependencies().front();
860361458b1SLoren Maggiore   memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});
861361458b1SLoren Maggiore 
862361458b1SLoren Maggiore   rewriter.replaceOp(memsetOp, {stream});
863361458b1SLoren Maggiore   return success();
864361458b1SLoren Maggiore }
865361458b1SLoren Maggiore 
matchAndRewrite(gpu::SetDefaultDeviceOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const86684718d37SKrzysztof Drewniak LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
86784718d37SKrzysztof Drewniak     gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
86884718d37SKrzysztof Drewniak     ConversionPatternRewriter &rewriter) const {
86984718d37SKrzysztof Drewniak   Location loc = op.getLoc();
87084718d37SKrzysztof Drewniak   setDefaultDeviceCallBuilder.create(loc, rewriter, {adaptor.devIndex()});
87184718d37SKrzysztof Drewniak   rewriter.replaceOp(op, {});
87284718d37SKrzysztof Drewniak   return success();
87384718d37SKrzysztof Drewniak }
87484718d37SKrzysztof Drewniak 
875a5f9cda1SChristian Sigg std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createGpuToLLVMConversionPass()876a5f9cda1SChristian Sigg mlir::createGpuToLLVMConversionPass() {
877a5f9cda1SChristian Sigg   return std::make_unique<GpuToLLVMConversionPass>();
878a5f9cda1SChristian Sigg }
8797d855605SButygin 
populateGpuToLLVMConversionPatterns(LLVMTypeConverter & converter,RewritePatternSet & patterns,StringRef gpuBinaryAnnotation)880*b7f93c28SJeff Niu void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
881*b7f93c28SJeff Niu                                                RewritePatternSet &patterns,
8827d855605SButygin                                                StringRef gpuBinaryAnnotation) {
8837d855605SButygin   converter.addConversion(
8847d855605SButygin       [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
8857d855605SButygin         return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
8867d855605SButygin       });
8877d855605SButygin   patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
8887d855605SButygin                ConvertDeallocOpToGpuRuntimeCallPattern,
8897d855605SButygin                ConvertHostRegisterOpToGpuRuntimeCallPattern,
8907d855605SButygin                ConvertMemcpyOpToGpuRuntimeCallPattern,
891361458b1SLoren Maggiore                ConvertMemsetOpToGpuRuntimeCallPattern,
89284718d37SKrzysztof Drewniak                ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
8937d855605SButygin                ConvertWaitAsyncOpToGpuRuntimeCallPattern,
8947d855605SButygin                ConvertWaitOpToGpuRuntimeCallPattern,
8957d855605SButygin                ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
8967d855605SButygin   patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
8977d855605SButygin                                                            gpuBinaryAnnotation);
8987d855605SButygin   patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
8997d855605SButygin }
900