//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert gpu.launch_func ops into a sequence
// of GPU runtime calls. As most GPU runtimes do not have a stable published
// ABI, this pass uses a slim runtime layer that builds on top of the public
// API from GPU runtime headers.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "../PassDetail.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";

namespace {

class GpuToLLVMConversionPass
    : public GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  GpuToLLVMConversionPass() = default;

  GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
      : GpuToLLVMConversionPassBase(other) {}

  // Run the dialect converter on the module.
  void runOnOperation() override;

private:
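  // An illustrative (hypothetical) invocation overriding the default binary
  // annotation, assuming the pass is registered under the 'gpu-to-llvm'
  // mnemonic:
  //   mlir-opt --gpu-to-llvm='gpu-binary-annotation=rocdl.hsaco' input.mlir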
  Option<std::string> gpuBinaryAnnotation{
      *this, "gpu-binary-annotation",
      llvm::cl::desc("Annotation attribute string for GPU binary"),
      llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
};

struct FunctionCallBuilder {
  FunctionCallBuilder(StringRef functionName, Type returnType,
                      ArrayRef<Type> argumentTypes)
      : functionName(functionName),
        functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
  LLVM::CallOp create(Location loc, OpBuilder &builder,
                      ArrayRef<Value> arguments) const;

  StringRef functionName;
  LLVM::LLVMFunctionType functionType;
};

template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  Type llvmPointerType =
      LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
  Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

  FunctionCallBuilder moduleLoadCallBuilder = {
      "mgpuModuleLoad",
      llvmPointerType /* void *module */,
      {llvmPointerType /* void *cubin */}};
  FunctionCallBuilder moduleUnloadCallBuilder = {
      "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
  FunctionCallBuilder moduleGetFunctionCallBuilder = {
      "mgpuModuleGetFunction",
      llvmPointerType /* void *function */,
      {
          llvmPointerType, /* void *module */
          llvmPointerType  /* char *name   */
      }};
  FunctionCallBuilder launchKernelCallBuilder = {
      "mgpuLaunchKernel",
      llvmVoidType,
      {
          llvmPointerType,        /* void* f */
          llvmIntPtrType,         /* intptr_t gridXDim */
          llvmIntPtrType,         /* intptr_t gridYDim */
          llvmIntPtrType,         /* intptr_t gridZDim */
          llvmIntPtrType,         /* intptr_t blockXDim */
          llvmIntPtrType,         /* intptr_t blockYDim */
          llvmIntPtrType,         /* intptr_t blockZDim */
          llvmInt32Type,          /* unsigned int sharedMemBytes */
          llvmPointerType,        /* void *hstream */
          llvmPointerPointerType, /* void **kernelParams */
          llvmPointerPointerType  /* void **extra */
      }};
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
};
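
// Note: the "mgpu*" functions named above are expected to be provided by a
// GPU runtime wrapper library. As a non-authoritative sketch that merely
// mirrors the launchKernelCallBuilder types above, the launch wrapper would
// be declared in C roughly as:
//
//   void mgpuLaunchKernel(void *function, intptr_t gridX, intptr_t gridY,
//                         intptr_t gridZ, intptr_t blockX, intptr_t blockY,
//                         intptr_t blockZ, unsigned sharedMemBytes,
//                         void *stream, void **params, void **extra);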

/// A rewrite pattern to convert gpu.host_register operations into a GPU
/// runtime call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert the !gpu.async.token operands of async.yield
/// operations into GPU runtime calls.
class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.launch_func operations into a sequence of
/// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
///
/// In essence, a gpu.launch_func operation gets compiled into the following
/// sequence of runtime calls:
///
/// * moduleLoad        -- loads the module given the cubin / hsaco data
/// * moduleGetFunction -- gets a handle to the actual kernel function
/// * streamCreate      -- creates a new compute stream on the GPU
/// * launchKernel      -- launches the kernel on a stream
/// * streamSynchronize -- waits for operations on the stream to finish
///
/// Intermediate data structures are allocated on the stack.
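///
/// For illustration only (all names invented), an input op might look like:
///
///   %token = gpu.launch_func async [%dep] @kernels::@my_kernel
///       blocks in (%gx, %gy, %gz) threads in (%bx, %by, %bz)
///       args(%arg0 : f32, %arg1 : memref<?xf32>)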
class ConvertLaunchFuncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
                                             StringRef gpuBinaryAnnotation)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        gpuBinaryAnnotation(gpuBinaryAnnotation) {}

private:
  Value generateParamsArray(gpu::LaunchFuncOp launchOp,
                            ArrayRef<Value> operands, OpBuilder &builder) const;
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;

  llvm::SmallString<32> gpuBinaryAnnotation;
};

class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
  using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
                                PatternRewriter &rewriter) const override {
    // GPU kernel modules are no longer necessary since we have a global
    // constant with the CUBIN or HSACO data.
    rewriter.eraseOp(op);
    return success();
  }
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
  LLVMTypeConverter converter(&getContext());
  RewritePatternSet patterns(&getContext());
  LLVMConversionTarget target(getContext());

  target.addIllegalDialect<gpu::GPUDialect>();

  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateStdToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);

  converter.addConversion(
      [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
        return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
      });
  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
                                                           gpuBinaryAnnotation);
  patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(
      loc, const_cast<LLVM::LLVMFunctionType &>(functionType).getReturnType(),
      builder.getSymbolRefAttr(function), arguments);
}

// Returns success if all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}
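
// For illustration only (invented names), an op such as
//   %memref, %token = gpu.alloc async [%dep] (%size) : memref<?xf32>
// has exactly one async dependency and produces an async token, so it
// satisfies both checks above.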

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, operands, rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.value().getType();
  auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(loc, op->getOperands(),
                                                       operands, rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, operands, rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  auto loc = allocOp.getLoc();
  auto adaptor = gpu::AllocOpAdaptor(operands, allocOp->getAttrDictionary());

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.dynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  Type elementPtrType = this->getElementPtrType(memRefType);
  auto stream = adaptor.asyncDependencies().front();
  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
  allocatedPtr =
      rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});

  return success();
}

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, operands, rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  auto adaptor =
      gpu::DeallocOpAdaptor(operands, deallocOp->getAttrDictionary());
  Value pointer =
      MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
  auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
  Value stream = adaptor.asyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {casted, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return value.getType().isa<gpu::AsyncTokenType>();
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token values are lowered to streams within the async.execute
// region, but are passed as events between regions. For each !gpu.async.token
// operand, we create an event and record it on the stream.
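//
// For illustration only (names and result types invented):
//
//   %done, %result = async.execute -> !async.value<!gpu.async.token> {
//     %token = gpu.wait async      // lowered to a stream inside the region
//     ...
//     async.yield %token : !gpu.async.token  // becomes event create + record
//   }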
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(operands.begin(), operands.end());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = operands[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.updateRootInPlace(yieldOp,
                             [&] { yieldOp->setOperands(newOperands); });
  return success();
}

// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(value.getType().isa<LLVM::LLVMPointerType>());
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return defOp.callee()->equals(functionName);
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
// with the stream/event operands. The operands are destroyed, i.e. it is
// assumed that they are not used afterwards or elsewhere. Otherwise we will
// get a runtime error. Eventually, we should guarantee this property.
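//
// For illustration only (invented names), `gpu.wait [%t0, %t1]` lowers
// roughly to:
//
//   llvm.call @mgpuStreamSynchronize(%stream0)  // if %t0 maps to a stream
//   llvm.call @mgpuStreamDestroy(%stream0)
//   llvm.call @mgpuEventSynchronize(%event1)    // if %t1 maps to an event
//   llvm.call @mgpuEventDestroy(%event1)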
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : operands) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands. The operands
// are destroyed, i.e. it is assumed that they are not used afterwards or
// elsewhere. Otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
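//
// For illustration only (invented names), `%t = gpu.wait async [%t0]` lowers
// roughly to:
//
//   %event = llvm.call @mgpuEventCreate()         // if %t0 maps to a stream;
//   llvm.call @mgpuEventRecord(%event, %stream0)  // placed after its last use
//   %stream = llvm.call @mgpuStreamCreate()
//   llvm.call @mgpuStreamWaitEvent(%stream, %event)
//   llvm.call @mgpuEventDestroy(%event)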
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair : llvm::zip(waitOp.asyncDependencies(), operands)) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event =
          eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
    OpBuilder &builder) const {
  auto loc = launchOp.getLoc();
  auto numKernelOperands = launchOp.getNumKernelOperands();
  auto arguments = getTypeConverter()->promoteOperands(
      loc, launchOp.getOperands().take_back(numKernelOperands),
      operands.take_back(numKernelOperands), builder);
  auto numArguments = arguments.size();
  SmallVector<Type, 4> argumentTypes;
  argumentTypes.reserve(numArguments);
  for (auto argument : arguments)
    argumentTypes.push_back(argument.getType());
  auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
                                                           argumentTypes);
  auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                              builder.getI32IntegerAttr(1));
  auto structPtr = builder.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
  auto arraySize = builder.create<LLVM::ConstantOp>(
      loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
  auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
                                                 arraySize, /*alignment=*/0);
  auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               builder.getI32IntegerAttr(0));
  for (auto en : llvm::enumerate(arguments)) {
    auto index = builder.create<LLVM::ConstantOp>(
        loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
    auto fieldPtr = builder.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
        ArrayRef<Value>{zero, index.getResult()});
    builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
    auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
                                                  arrayPtr, index.getResult());
    auto casted =
        builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
    builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
  }
  return arrayPtr;
}

// Generates an LLVM IR dialect global that contains the name of the given
// kernel function as a C string, and returns a pointer to its beginning.
// The code is essentially:
//
// llvm.global constant @kernel_name("function_name\00")
// func(...) {
//   %0 = llvm.addressof @kernel_name
//   %1 = llvm.constant (0 : index)
//   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
// }
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
    StringRef moduleName, StringRef name, Location loc,
    OpBuilder &builder) const {
  // Make sure the trailing zero is included in the constant.
  std::vector<char> kernelName(name.begin(), name.end());
  kernelName.push_back('\0');

  std::string globalName =
      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
  return LLVM::createGlobalString(
      loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
      LLVM::Linkage::Internal);
}

// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = call %moduleLoad(%0)
// %2 = <see generateKernelNameConstant>
// %3 = call %moduleGetFunction(%1, %2)
// %4 = call %streamCreate()
// %5 = <see generateParamsArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
// call %streamDestroy(%4)
// call %moduleUnload(%1)
//
// If the op is async, the stream corresponds to the (single) async dependency
// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, operands, rewriter)))
    return failure();

  if (launchOp.asyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launchOp, launchOp.getKernelModuleName());
  assert(kernelModule && "expected a kernel module");

  auto binaryAttr =
      kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
  if (!binaryAttr) {
    kernelModule.emitOpError()
        << "missing " << gpuBinaryAnnotation << " attribute";
    return failure();
  }

  SmallString<128> nameBuffer(kernelModule.getName());
  nameBuffer.append(kGpuBinaryStorageSuffix);
  Value data =
      LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
                               binaryAttr.getValue(), LLVM::Linkage::Internal);

  auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
  // Get the function from the module. The name corresponds to the name of
  // the kernel function.
  auto kernelName = generateKernelNameConstant(
      launchOp.getKernelModuleName(), launchOp.getKernelName(), loc, rewriter);
  auto function = moduleGetFunctionCallBuilder.create(
      loc, rewriter, {module.getResult(0), kernelName});
  auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                                rewriter.getI32IntegerAttr(0));
  auto adaptor =
      gpu::LaunchFuncOpAdaptor(operands, launchOp->getAttrDictionary());
  Value stream =
      adaptor.asyncDependencies().empty()
          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
          : adaptor.asyncDependencies().front();
  // Create array of pointers to kernel arguments.
  auto kernelParams = generateParamsArray(launchOp, operands, rewriter);
  auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
  launchKernelCallBuilder.create(loc, rewriter,
                                 {function.getResult(0), launchOp.gridSizeX(),
                                  launchOp.gridSizeY(), launchOp.gridSizeZ(),
                                  launchOp.blockSizeX(), launchOp.blockSizeY(),
                                  launchOp.blockSizeZ(),
                                  /*sharedMemBytes=*/zero, stream, kernelParams,
                                  /*extra=*/nullpointer});

  if (launchOp.asyncToken()) {
    // Async launch: make dependent ops use the same stream.
    rewriter.replaceOp(launchOp, {stream});
  } else {
    // Synchronize with host and destroy stream. This must be the stream
    // created above (with no other uses) because we check that the
    // synchronous version does not have any async dependencies.
    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
    streamDestroyCallBuilder.create(loc, rewriter, stream);
    rewriter.eraseOp(launchOp);
  }
  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));

  return success();
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memcpyOp.src().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memcpyOp, operands, rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();
  auto adaptor = gpu::MemcpyOpAdaptor(operands, memcpyOp->getAttrDictionary());

  MemRefDescriptor srcDesc(adaptor.src());

  Value numElements =
      memRefType.hasStaticShape()
          ? createIndexConstant(rewriter, loc, memRefType.getNumElements())
          // For identity layouts (verified above), the number of elements is
          // stride[0] * size[0].
          : rewriter.create<LLVM::MulOp>(loc, srcDesc.stride(rewriter, loc, 0),
                                         srcDesc.size(rewriter, loc, 0));

  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(
      loc, elementPtrType, ArrayRef<Value>{nullPtr, numElements});
  auto sizeBytes =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);

  auto src = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType,
      MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createGpuToLLVMConversionPass() {
  return std::make_unique<GpuToLLVMConversionPass>();
}
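
// A minimal usage sketch from C++ (assuming a standard PassManager setup):
//
//   PassManager pm(module.getContext());
//   pm.addPass(createGpuToLLVMConversionPass());
//   if (failed(pm.run(module)))
//     /* handle the error */;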