//===- ConvertLaunchFuncToGpuRuntimeCalls.cpp - MLIR GPU lowering passes --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to convert gpu.launch_func op into a sequence of
// GPU runtime calls. As most GPU runtimes do not have a stable published ABI,
// this pass uses a slim runtime layer that builds on top of the public API
// from GPU runtime headers.
//
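// For illustration, a gpu.launch_func roughly lowers to the following call
// sequence (a sketch; value names are illustrative, and the precise sequence
// is documented on ConvertLaunchFuncOpToGpuRuntimeCallPattern below):
//
//   %module = llvm.call @mgpuModuleLoad(%binary)
//   %func   = llvm.call @mgpuModuleGetFunction(%module, %name)
//   %stream = llvm.call @mgpuStreamCreate()
//   llvm.call @mgpuLaunchKernel(%func, ..., %stream, %params, %extra)
//   llvm.call @mgpuStreamSynchronize(%stream)
//   llvm.call @mgpuStreamDestroy(%stream)
//   llvm.call @mgpuModuleUnload(%module)
//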
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"

#include "../PassDetail.h"
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/Async/IR/Async.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";

namespace {

class GpuToLLVMConversionPass
    : public GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
  GpuToLLVMConversionPass() = default;

  GpuToLLVMConversionPass(const GpuToLLVMConversionPass &other)
      : GpuToLLVMConversionPassBase(other) {}

  // Run the dialect converter on the module.
  void runOnOperation() override;

private:
  Option<std::string> gpuBinaryAnnotation{
      *this, "gpu-binary-annotation",
      llvm::cl::desc("Annotation attribute string for GPU binary"),
      llvm::cl::init(gpu::getDefaultGpuBinaryAnnotation())};
};
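
// A minimal usage sketch (assuming the pass is registered under the
// `gpu-to-llvm` argument): the binary annotation can be overridden on the
// command line, e.g. for a pipeline that stores HSACO blobs:
//
//   mlir-opt --gpu-to-llvm='gpu-binary-annotation=rocdl.hsaco' input.mlir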

struct FunctionCallBuilder {
  FunctionCallBuilder(StringRef functionName, Type returnType,
                      ArrayRef<Type> argumentTypes)
      : functionName(functionName),
        functionType(LLVM::LLVMFunctionType::get(returnType, argumentTypes)) {}
  LLVM::CallOp create(Location loc, OpBuilder &builder,
                      ArrayRef<Value> arguments) const;

  StringRef functionName;
  LLVM::LLVMFunctionType functionType;
};
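
// Usage sketch (the wrapper name below is hypothetical, for illustration
// only; the real wrappers are listed further down):
//
//   FunctionCallBuilder fooCallBuilder = {
//       "mgpuFoo", llvmVoidType, {llvmPointerType /* void *ptr */}};
//   fooCallBuilder.create(loc, rewriter, {ptr});
//
// emits `llvm.call @mgpuFoo(%ptr)` and, if @mgpuFoo has not been declared
// yet, inserts its declaration at module scope (see FunctionCallBuilder::create
// below).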

template <typename OpTy>
class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
public:
  explicit ConvertOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToLLVMPattern<OpTy>(typeConverter) {}

protected:
  Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
                       MemRefType type, MemRefDescriptor desc) const {
    return type.hasStaticShape()
               ? ConvertToLLVMPattern::createIndexConstant(
                     rewriter, loc, type.getNumElements())
               // For identity maps (verified by caller), the number of
               // elements is stride[0] * size[0].
               : rewriter.create<LLVM::MulOp>(loc,
                                              desc.stride(rewriter, loc, 0),
                                              desc.size(rewriter, loc, 0));
  }

  MLIRContext *context = &this->getTypeConverter()->getContext();

  Type llvmVoidType = LLVM::LLVMVoidType::get(context);
  Type llvmPointerType =
      LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
  Type llvmPointerPointerType = LLVM::LLVMPointerType::get(llvmPointerType);
  Type llvmInt8Type = IntegerType::get(context, 8);
  Type llvmInt32Type = IntegerType::get(context, 32);
  Type llvmInt64Type = IntegerType::get(context, 64);
  Type llvmIntPtrType = IntegerType::get(
      context, this->getTypeConverter()->getPointerBitwidth(0));

  FunctionCallBuilder moduleLoadCallBuilder = {
      "mgpuModuleLoad",
      llvmPointerType /* void *module */,
      {llvmPointerType /* void *cubin */}};
  FunctionCallBuilder moduleUnloadCallBuilder = {
      "mgpuModuleUnload", llvmVoidType, {llvmPointerType /* void *module */}};
  FunctionCallBuilder moduleGetFunctionCallBuilder = {
      "mgpuModuleGetFunction",
      llvmPointerType /* void *function */,
      {
          llvmPointerType, /* void *module */
          llvmPointerType  /* char *name */
      }};
  FunctionCallBuilder launchKernelCallBuilder = {
      "mgpuLaunchKernel",
      llvmVoidType,
      {
          llvmPointerType, /* void* f */
          llvmIntPtrType,  /* intptr_t gridXDim */
          llvmIntPtrType,  /* intptr_t gridYDim */
          llvmIntPtrType,  /* intptr_t gridZDim */
          llvmIntPtrType,  /* intptr_t blockXDim */
          llvmIntPtrType,  /* intptr_t blockYDim */
          llvmIntPtrType,  /* intptr_t blockZDim */
          llvmInt32Type,   /* unsigned int sharedMemBytes */
          llvmPointerType, /* void *hstream */
          llvmPointerPointerType, /* void **kernelParams */
          llvmPointerPointerType  /* void **extra */
      }};
  FunctionCallBuilder streamCreateCallBuilder = {
      "mgpuStreamCreate", llvmPointerType /* void *stream */, {}};
  FunctionCallBuilder streamDestroyCallBuilder = {
      "mgpuStreamDestroy", llvmVoidType, {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamSynchronizeCallBuilder = {
      "mgpuStreamSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *stream */}};
  FunctionCallBuilder streamWaitEventCallBuilder = {
      "mgpuStreamWaitEvent",
      llvmVoidType,
      {llvmPointerType /* void *stream */, llvmPointerType /* void *event */}};
  FunctionCallBuilder eventCreateCallBuilder = {
      "mgpuEventCreate", llvmPointerType /* void *event */, {}};
  FunctionCallBuilder eventDestroyCallBuilder = {
      "mgpuEventDestroy", llvmVoidType, {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventSynchronizeCallBuilder = {
      "mgpuEventSynchronize",
      llvmVoidType,
      {llvmPointerType /* void *event */}};
  FunctionCallBuilder eventRecordCallBuilder = {
      "mgpuEventRecord",
      llvmVoidType,
      {llvmPointerType /* void *event */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder hostRegisterCallBuilder = {
      "mgpuMemHostRegisterMemRef",
      llvmVoidType,
      {llvmIntPtrType /* intptr_t rank */,
       llvmPointerType /* void *memrefDesc */,
       llvmIntPtrType /* intptr_t elementSizeBytes */}};
  FunctionCallBuilder allocCallBuilder = {
      "mgpuMemAlloc",
      llvmPointerType /* void * */,
      {llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder deallocCallBuilder = {
      "mgpuMemFree",
      llvmVoidType,
      {llvmPointerType /* void *ptr */, llvmPointerType /* void *stream */}};
  FunctionCallBuilder memcpyCallBuilder = {
      "mgpuMemcpy",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder memsetCallBuilder = {
      "mgpuMemset32",
      llvmVoidType,
      {llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
       llvmIntPtrType /* intptr_t sizeBytes */,
       llvmPointerType /* void *stream */}};
  FunctionCallBuilder setDefaultDeviceCallBuilder = {
      "mgpuSetDefaultDevice",
      llvmVoidType,
      {llvmInt32Type /* uint32_t devIndex */}};
};

/// A rewrite pattern to convert gpu.host_register operations into a GPU
/// runtime call. Currently it supports CUDA and ROCm (HIP).
class ConvertHostRegisterOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp> {
public:
  ConvertHostRegisterOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::HostRegisterOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.alloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertAllocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp> {
public:
  ConvertAllocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::AllocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::AllocOp allocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.dealloc operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertDeallocOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp> {
public:
  ConvertDeallocOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::DeallocOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::DeallocOp deallocOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

class ConvertAsyncYieldToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<async::YieldOp> {
public:
  ConvertAsyncYieldToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<async::YieldOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(async::YieldOp yieldOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.wait async operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertWaitAsyncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp> {
public:
  ConvertWaitAsyncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::WaitOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::WaitOp waitOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.launch_func operations into a sequence of
/// GPU runtime calls. Currently it supports CUDA and ROCm (HIP).
///
/// In essence, a gpu.launch_func operation gets compiled into the following
/// sequence of runtime calls:
///
/// * moduleLoad -- loads the module given the cubin / hsaco data
/// * moduleGetFunction -- gets a handle to the actual kernel function
/// * streamCreate -- creates a new compute stream on the GPU
/// * launchKernel -- launches the kernel on a stream
/// * streamSynchronize -- waits for operations on the stream to finish
///
/// Intermediate data structures are allocated on the stack.
class ConvertLaunchFuncOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp> {
public:
  ConvertLaunchFuncOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter,
                                             StringRef gpuBinaryAnnotation)
      : ConvertOpToGpuRuntimeCallPattern<gpu::LaunchFuncOp>(typeConverter),
        gpuBinaryAnnotation(gpuBinaryAnnotation) {}

private:
  Value generateParamsArray(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                            OpBuilder &builder) const;
  Value generateKernelNameConstant(StringRef moduleName, StringRef name,
                                   Location loc, OpBuilder &builder) const;

  LogicalResult
  matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;

  llvm::SmallString<32> gpuBinaryAnnotation;
};

class EraseGpuModuleOpPattern : public OpRewritePattern<gpu::GPUModuleOp> {
  using OpRewritePattern<gpu::GPUModuleOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(gpu::GPUModuleOp op,
                                PatternRewriter &rewriter) const override {
    // GPU kernel modules are no longer necessary since we have a global
    // constant with the CUBIN or HSACO data.
    rewriter.eraseOp(op);
    return success();
  }
};

/// A rewrite pattern to convert gpu.memcpy operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemcpyOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp> {
public:
  ConvertMemcpyOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemcpyOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.memset operations into a GPU runtime
/// call. Currently it supports CUDA and ROCm (HIP).
class ConvertMemsetOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp> {
public:
  ConvertMemsetOpToGpuRuntimeCallPattern(LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::MemsetOp>(typeConverter) {}

private:
  LogicalResult
  matchAndRewrite(gpu::MemsetOp memsetOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// A rewrite pattern to convert gpu.set_default_device to a GPU runtime call.
/// Currently it supports CUDA and ROCm (HIP).
class ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern
    : public ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp> {
public:
  ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern(
      LLVMTypeConverter &typeConverter)
      : ConvertOpToGpuRuntimeCallPattern<gpu::SetDefaultDeviceOp>(
            typeConverter) {}

  LogicalResult
  matchAndRewrite(gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};
} // namespace

void GpuToLLVMConversionPass::runOnOperation() {
  LLVMTypeConverter converter(&getContext());
  RewritePatternSet patterns(&getContext());
  LLVMConversionTarget target(getContext());

  target.addIllegalDialect<gpu::GPUDialect>();

  mlir::arith::populateArithmeticToLLVMConversionPatterns(converter, patterns);
  mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
  populateVectorToLLVMConversionPatterns(converter, patterns);
  populateMemRefToLLVMConversionPatterns(converter, patterns);
  populateFuncToLLVMConversionPatterns(converter, patterns);
  populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                    target);
  populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation);

  if (failed(
          applyPartialConversion(getOperation(), target, std::move(patterns))))
    signalPassFailure();
}

LLVM::CallOp FunctionCallBuilder::create(Location loc, OpBuilder &builder,
                                         ArrayRef<Value> arguments) const {
  auto module = builder.getBlock()->getParent()->getParentOfType<ModuleOp>();
  auto function = [&] {
    if (auto function = module.lookupSymbol<LLVM::LLVMFuncOp>(functionName))
      return function;
    return OpBuilder::atBlockEnd(module.getBody())
        .create<LLVM::LLVMFuncOp>(loc, functionName, functionType);
  }();
  return builder.create<LLVM::CallOp>(loc, function, arguments);
}

// Returns success if all operands are of LLVM type.
static LogicalResult areAllLLVMTypes(Operation *op, ValueRange operands,
                                     ConversionPatternRewriter &rewriter) {
  if (!llvm::all_of(operands, [](Value value) {
        return LLVM::isCompatibleType(value.getType());
      }))
    return rewriter.notifyMatchFailure(
        op, "Cannot convert if operands aren't of LLVM type.");
  return success();
}

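// Returns success if `op` is an async op with exactly one async dependency,
// e.g. (a sketch) `%token = gpu.alloc async [%dep] ...`. The synchronous form
// of the op, or a form with several async dependencies, does not match.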
static LogicalResult
isAsyncWithOneDependency(ConversionPatternRewriter &rewriter,
                         gpu::AsyncOpInterface op) {
  if (op.getAsyncDependencies().size() != 1)
    return rewriter.notifyMatchFailure(
        op, "Can only convert with exactly one async dependency.");

  if (!op.getAsyncToken())
    return rewriter.notifyMatchFailure(op, "Can convert only async version.");

  return success();
}

LogicalResult ConvertHostRegisterOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::HostRegisterOp hostRegisterOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto *op = hostRegisterOp.getOperation();
  if (failed(areAllLLVMTypes(op, adaptor.getOperands(), rewriter)))
    return failure();

  Location loc = op->getLoc();

  auto memRefType = hostRegisterOp.value().getType();
  auto elementType = memRefType.cast<UnrankedMemRefType>().getElementType();
  auto elementSize = getSizeInBytes(loc, elementType, rewriter);

  auto arguments = getTypeConverter()->promoteOperands(
      loc, op->getOperands(), adaptor.getOperands(), rewriter);
  arguments.push_back(elementSize);
  hostRegisterCallBuilder.create(loc, rewriter, arguments);

  rewriter.eraseOp(op);
  return success();
}

LogicalResult ConvertAllocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::AllocOp allocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  MemRefType memRefType = allocOp.getType();

  if (failed(areAllLLVMTypes(allocOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, allocOp)))
    return failure();

  auto loc = allocOp.getLoc();

  // Get shape of the memref as values: static sizes are constant
  // values and dynamic sizes are passed to 'alloc' as operands.
  SmallVector<Value, 4> shape;
  SmallVector<Value, 4> strides;
  Value sizeBytes;
  getMemRefDescriptorSizes(loc, memRefType, adaptor.dynamicSizes(), rewriter,
                           shape, strides, sizeBytes);

  // Allocate the underlying buffer and store a pointer to it in the MemRef
  // descriptor.
  Type elementPtrType = this->getElementPtrType(memRefType);
  auto stream = adaptor.asyncDependencies().front();
  Value allocatedPtr =
      allocCallBuilder.create(loc, rewriter, {sizeBytes, stream}).getResult(0);
  allocatedPtr =
      rewriter.create<LLVM::BitcastOp>(loc, elementPtrType, allocatedPtr);

  // No alignment.
  Value alignedPtr = allocatedPtr;

  // Create the MemRef descriptor.
  auto memRefDescriptor = this->createMemRefDescriptor(
      loc, memRefType, allocatedPtr, alignedPtr, shape, strides, rewriter);

  rewriter.replaceOp(allocOp, {memRefDescriptor, stream});

  return success();
}

LogicalResult ConvertDeallocOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::DeallocOp deallocOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(deallocOp, adaptor.getOperands(), rewriter)) ||
      failed(isAsyncWithOneDependency(rewriter, deallocOp)))
    return failure();

  Location loc = deallocOp.getLoc();

  Value pointer =
      MemRefDescriptor(adaptor.memref()).allocatedPtr(rewriter, loc);
  auto casted = rewriter.create<LLVM::BitcastOp>(loc, llvmPointerType, pointer);
  Value stream = adaptor.asyncDependencies().front();
  deallocCallBuilder.create(loc, rewriter, {casted, stream});

  rewriter.replaceOp(deallocOp, {stream});
  return success();
}

static bool isGpuAsyncTokenType(Value value) {
  return value.getType().isa<gpu::AsyncTokenType>();
}

// Converts !gpu.async.token operands of `async.yield` to runtime calls. The
// !gpu.async.token values are lowered to streams within the async.execute
// region, but are passed as events between regions. For each !gpu.async.token
// operand, we create an event and record it on the stream.
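// For illustration, roughly (a sketch; value names are illustrative):
//
//   async.yield %stream : !gpu.async.token
//
// becomes
//
//   %event = llvm.call @mgpuEventCreate()
//   llvm.call @mgpuEventRecord(%event, %stream)
//   llvm.call @mgpuStreamDestroy(%stream)
//   async.yield %event : !llvm.ptr<i8>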
LogicalResult ConvertAsyncYieldToGpuRuntimeCallPattern::matchAndRewrite(
    async::YieldOp yieldOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (llvm::none_of(yieldOp.operands(), isGpuAsyncTokenType))
    return rewriter.notifyMatchFailure(yieldOp, "no gpu async token operand");

  Location loc = yieldOp.getLoc();
  SmallVector<Value, 4> newOperands(adaptor.getOperands());
  llvm::SmallDenseSet<Value> streams;
  for (auto &operand : yieldOp->getOpOperands()) {
    if (!isGpuAsyncTokenType(operand.get()))
      continue;
    auto idx = operand.getOperandNumber();
    auto stream = adaptor.getOperands()[idx];
    auto event = eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
    eventRecordCallBuilder.create(loc, rewriter, {event, stream});
    newOperands[idx] = event;
    streams.insert(stream);
  }
  for (auto stream : streams)
    streamDestroyCallBuilder.create(loc, rewriter, {stream});

  rewriter.updateRootInPlace(yieldOp,
                             [&] { yieldOp->setOperands(newOperands); });
  return success();
}

// Returns whether `value` is the result of an LLVM::CallOp to `functionName`.
static bool isDefinedByCallTo(Value value, StringRef functionName) {
  assert(value.getType().isa<LLVM::LLVMPointerType>());
  if (auto defOp = value.getDefiningOp<LLVM::CallOp>())
    return defOp.getCallee()->equals(functionName);
  return false;
}

// Converts `gpu.wait` to runtime calls. The converted op synchronizes the host
// with the stream/event operands. The operands are destroyed; that is, it is
// assumed that they are not used afterwards or elsewhere. Otherwise we will
// get a runtime error. Eventually, we should guarantee this property.
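// For illustration, when the converted operand is a stream (a sketch; value
// names are illustrative), roughly:
//
//   gpu.wait [%t]
//
// becomes
//
//   llvm.call @mgpuStreamSynchronize(%stream)
//   llvm.call @mgpuStreamDestroy(%stream)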
LogicalResult ConvertWaitOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Cannot convert async op.");

  Location loc = waitOp.getLoc();

  for (auto operand : adaptor.getOperands()) {
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream.
      streamSynchronizeCallBuilder.create(loc, rewriter, {operand});
      streamDestroyCallBuilder.create(loc, rewriter, {operand});
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      eventSynchronizeCallBuilder.create(loc, rewriter, {operand});
      eventDestroyCallBuilder.create(loc, rewriter, {operand});
    }
  }

  rewriter.eraseOp(waitOp);
  return success();
}

// Converts `gpu.wait async` to runtime calls. The converted op creates a new
// stream that is synchronized with the stream/event operands. The operands are
// destroyed; that is, it is assumed that they are not used afterwards or
// elsewhere. Otherwise we will get a runtime error. Eventually, we should
// guarantee this property.
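// For illustration, when the converted operand is an event (a sketch; value
// names are illustrative), roughly:
//
//   %t1 = gpu.wait async [%t0]
//
// becomes
//
//   %stream = llvm.call @mgpuStreamCreate()
//   llvm.call @mgpuStreamWaitEvent(%stream, %event)
//   llvm.call @mgpuEventDestroy(%event)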
LogicalResult ConvertWaitAsyncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::WaitOp waitOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (!waitOp.asyncToken())
    return rewriter.notifyMatchFailure(waitOp, "Can only convert async op.");

  Location loc = waitOp.getLoc();

  auto insertionPoint = rewriter.saveInsertionPoint();
  SmallVector<Value, 1> events;
  for (auto pair :
       llvm::zip(waitOp.asyncDependencies(), adaptor.getOperands())) {
    auto operand = std::get<1>(pair);
    if (isDefinedByCallTo(operand, streamCreateCallBuilder.functionName)) {
      // The converted operand's definition created a stream. Insert an event
      // into the stream just after the last use of the original token operand.
      auto *defOp = std::get<0>(pair).getDefiningOp();
      rewriter.setInsertionPointAfter(defOp);
      auto event =
          eventCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
      eventRecordCallBuilder.create(loc, rewriter, {event, operand});
      events.push_back(event);
    } else {
      // Otherwise the converted operand is an event. This assumes that we use
      // events in control flow code as well.
      events.push_back(operand);
    }
  }
  rewriter.restoreInsertionPoint(insertionPoint);
  auto stream = streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0);
  for (auto event : events)
    streamWaitEventCallBuilder.create(loc, rewriter, {stream, event});
  for (auto event : events)
    eventDestroyCallBuilder.create(loc, rewriter, {event});
  rewriter.replaceOp(waitOp, {stream});

  return success();
}

// Creates a struct containing all kernel parameters on the stack and returns
// an array of type-erased pointers to the fields of the struct. The array can
// then be passed to the CUDA / ROCm (HIP) kernel launch calls.
// The generated code is essentially as follows:
//
// %struct = alloca(sizeof(struct { Parameters... }))
// %array = alloca(NumParameters * sizeof(void *))
// for (i : [0, NumParameters))
//   %fieldPtr = llvm.getelementptr %struct[0, i]
//   llvm.store parameters[i], %fieldPtr
//   %elementPtr = llvm.getelementptr %array[i]
//   llvm.store %fieldPtr, %elementPtr
// return %array
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateParamsArray(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor, OpBuilder &builder) const {
  auto loc = launchOp.getLoc();
  auto numKernelOperands = launchOp.getNumKernelOperands();
  auto arguments = getTypeConverter()->promoteOperands(
      loc, launchOp.getOperands().take_back(numKernelOperands),
      adaptor.getOperands().take_back(numKernelOperands), builder);
  auto numArguments = arguments.size();
  SmallVector<Type, 4> argumentTypes;
  argumentTypes.reserve(numArguments);
  for (auto argument : arguments)
    argumentTypes.push_back(argument.getType());
  auto structType = LLVM::LLVMStructType::getNewIdentified(context, StringRef(),
                                                           argumentTypes);
  auto one = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                              builder.getI32IntegerAttr(1));
  auto structPtr = builder.create<LLVM::AllocaOp>(
      loc, LLVM::LLVMPointerType::get(structType), one, /*alignment=*/0);
  auto arraySize = builder.create<LLVM::ConstantOp>(
      loc, llvmInt32Type, builder.getI32IntegerAttr(numArguments));
  auto arrayPtr = builder.create<LLVM::AllocaOp>(loc, llvmPointerPointerType,
                                                 arraySize, /*alignment=*/0);
  auto zero = builder.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                               builder.getI32IntegerAttr(0));
  for (const auto &en : llvm::enumerate(arguments)) {
    auto index = builder.create<LLVM::ConstantOp>(
        loc, llvmInt32Type, builder.getI32IntegerAttr(en.index()));
    auto fieldPtr = builder.create<LLVM::GEPOp>(
        loc, LLVM::LLVMPointerType::get(argumentTypes[en.index()]), structPtr,
        ArrayRef<Value>{zero, index.getResult()});
    builder.create<LLVM::StoreOp>(loc, en.value(), fieldPtr);
    auto elementPtr = builder.create<LLVM::GEPOp>(loc, llvmPointerPointerType,
                                                  arrayPtr, index.getResult());
    auto casted =
        builder.create<LLVM::BitcastOp>(loc, llvmPointerType, fieldPtr);
    builder.create<LLVM::StoreOp>(loc, casted, elementPtr);
  }
  return arrayPtr;
}

// Generates an LLVM IR dialect global that contains the name of the given
// kernel function as a C string, and returns a pointer to its beginning.
// The code is essentially:
//
// llvm.global constant @kernel_name("function_name\00")
// func(...) {
//   %0 = llvm.addressof @kernel_name
//   %1 = llvm.constant (0 : index)
//   %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
// }
Value ConvertLaunchFuncOpToGpuRuntimeCallPattern::generateKernelNameConstant(
    StringRef moduleName, StringRef name, Location loc,
    OpBuilder &builder) const {
  // Make sure the trailing zero is included in the constant.
  std::vector<char> kernelName(name.begin(), name.end());
  kernelName.push_back('\0');

  std::string globalName =
      std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
  return LLVM::createGlobalString(
      loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
      LLVM::Linkage::Internal);
}

// Emits LLVM IR to launch a kernel function. Expects the module that contains
// the compiled kernel function as a cubin in the 'nvvm.cubin' attribute, or a
// hsaco in the 'rocdl.hsaco' attribute of the kernel function in the IR.
//
// %0 = call %binarygetter
// %1 = call %moduleLoad(%0)
// %2 = <see generateKernelNameConstant>
// %3 = call %moduleGetFunction(%1, %2)
// %4 = call %streamCreate()
// %5 = <see generateParamsArray>
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
// call %streamSynchronize(%4)
// call %streamDestroy(%4)
// call %moduleUnload(%1)
//
// If the op is async, the stream corresponds to the (single) async dependency
// as well as the async token the op produces.
LogicalResult ConvertLaunchFuncOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  if (failed(areAllLLVMTypes(launchOp, adaptor.getOperands(), rewriter)))
    return failure();

  if (launchOp.asyncDependencies().size() > 1)
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert with more than one async dependency.");

  // Fail when the synchronous version of the op has async dependencies. The
  // lowering destroys the stream, and we do not want to check that there is no
  // use of the stream after this op.
  if (!launchOp.asyncToken() && !launchOp.asyncDependencies().empty())
    return rewriter.notifyMatchFailure(
        launchOp, "Cannot convert non-async op with async dependencies.");

  Location loc = launchOp.getLoc();

  // Create an LLVM global with CUBIN extracted from the kernel annotation and
  // obtain a pointer to the first byte in it.
  auto kernelModule = SymbolTable::lookupNearestSymbolFrom<gpu::GPUModuleOp>(
      launchOp, launchOp.getKernelModuleName());
  assert(kernelModule && "expected a kernel module");

  auto binaryAttr =
      kernelModule->getAttrOfType<StringAttr>(gpuBinaryAnnotation);
  if (!binaryAttr) {
    kernelModule.emitOpError()
        << "missing " << gpuBinaryAnnotation << " attribute";
    return failure();
  }

  SmallString<128> nameBuffer(kernelModule.getName());
  nameBuffer.append(kGpuBinaryStorageSuffix);
  Value data =
      LLVM::createGlobalString(loc, rewriter, nameBuffer.str(),
                               binaryAttr.getValue(), LLVM::Linkage::Internal);

  auto module = moduleLoadCallBuilder.create(loc, rewriter, data);
  // Get the function from the module. The name corresponds to the name of
  // the kernel function.
  auto kernelName = generateKernelNameConstant(
      launchOp.getKernelModuleName().getValue(),
      launchOp.getKernelName().getValue(), loc, rewriter);
  auto function = moduleGetFunctionCallBuilder.create(
      loc, rewriter, {module.getResult(0), kernelName});
  auto zero = rewriter.create<LLVM::ConstantOp>(loc, llvmInt32Type,
                                                rewriter.getI32IntegerAttr(0));
  Value stream =
      adaptor.asyncDependencies().empty()
          ? streamCreateCallBuilder.create(loc, rewriter, {}).getResult(0)
          : adaptor.asyncDependencies().front();
  // Create array of pointers to kernel arguments.
  auto kernelParams = generateParamsArray(launchOp, adaptor, rewriter);
  auto nullpointer = rewriter.create<LLVM::NullOp>(loc, llvmPointerPointerType);
  Value dynamicSharedMemorySize = launchOp.dynamicSharedMemorySize()
                                      ? launchOp.dynamicSharedMemorySize()
                                      : zero;
  launchKernelCallBuilder.create(
      loc, rewriter,
      {function.getResult(0), adaptor.gridSizeX(), adaptor.gridSizeY(),
       adaptor.gridSizeZ(), adaptor.blockSizeX(), adaptor.blockSizeY(),
       adaptor.blockSizeZ(), dynamicSharedMemorySize, stream, kernelParams,
       /*extra=*/nullpointer});

  if (launchOp.asyncToken()) {
    // Async launch: make dependent ops use the same stream.
    rewriter.replaceOp(launchOp, {stream});
  } else {
    // Synchronize with host and destroy stream. This must be the stream
    // created above (with no other uses) because we check that the
    // synchronous version does not have any async dependencies.
    streamSynchronizeCallBuilder.create(loc, rewriter, stream);
    streamDestroyCallBuilder.create(loc, rewriter, stream);
    rewriter.eraseOp(launchOp);
  }
  moduleUnloadCallBuilder.create(loc, rewriter, module.getResult(0));

  return success();
}

LogicalResult ConvertMemcpyOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemcpyOp memcpyOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memcpyOp.src().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memcpyOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memcpyOp)))
    return failure();

  auto loc = memcpyOp.getLoc();

  MemRefDescriptor srcDesc(adaptor.src());
  Value numElements = getNumElements(rewriter, loc, memRefType, srcDesc);

  Type elementPtrType = getElementPtrType(memRefType);
  Value nullPtr = rewriter.create<LLVM::NullOp>(loc, elementPtrType);
  Value gepPtr = rewriter.create<LLVM::GEPOp>(loc, elementPtrType, nullPtr,
                                              ArrayRef<Value>{numElements});
  auto sizeBytes =
      rewriter.create<LLVM::PtrToIntOp>(loc, getIndexType(), gepPtr);

  auto src = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, srcDesc.alignedPtr(rewriter, loc));
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType,
      MemRefDescriptor(adaptor.dst()).alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memcpyCallBuilder.create(loc, rewriter, {dst, src, sizeBytes, stream});

  rewriter.replaceOp(memcpyOp, {stream});

  return success();
}

LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::MemsetOp memsetOp, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  auto memRefType = memsetOp.dst().getType().cast<MemRefType>();

  if (failed(areAllLLVMTypes(memsetOp, adaptor.getOperands(), rewriter)) ||
      !isConvertibleAndHasIdentityMaps(memRefType) ||
      failed(isAsyncWithOneDependency(rewriter, memsetOp)))
    return failure();

  auto loc = memsetOp.getLoc();

  Type valueType = adaptor.value().getType();
  if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
    return rewriter.notifyMatchFailure(memsetOp,
                                       "value must be a 32 bit scalar");
  }

  MemRefDescriptor dstDesc(adaptor.dst());
  Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

  auto value =
      rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.value());
  auto dst = rewriter.create<LLVM::BitcastOp>(
      loc, llvmPointerType, dstDesc.alignedPtr(rewriter, loc));

  auto stream = adaptor.asyncDependencies().front();
  memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});

  rewriter.replaceOp(memsetOp, {stream});
  return success();
}

LogicalResult ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern::matchAndRewrite(
    gpu::SetDefaultDeviceOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  setDefaultDeviceCallBuilder.create(loc, rewriter, {adaptor.devIndex()});
  rewriter.replaceOp(op, {});
  return success();
}

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
mlir::createGpuToLLVMConversionPass() {
  return std::make_unique<GpuToLLVMConversionPass>();
}

void mlir::populateGpuToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                               RewritePatternSet &patterns,
                                               StringRef gpuBinaryAnnotation) {
  converter.addConversion(
      [context = &converter.getContext()](gpu::AsyncTokenType type) -> Type {
        return LLVM::LLVMPointerType::get(IntegerType::get(context, 8));
      });
  patterns.add<ConvertAllocOpToGpuRuntimeCallPattern,
               ConvertDeallocOpToGpuRuntimeCallPattern,
               ConvertHostRegisterOpToGpuRuntimeCallPattern,
               ConvertMemcpyOpToGpuRuntimeCallPattern,
               ConvertMemsetOpToGpuRuntimeCallPattern,
               ConvertSetDefaultDeviceOpToGpuRuntimeCallPattern,
               ConvertWaitAsyncOpToGpuRuntimeCallPattern,
               ConvertWaitOpToGpuRuntimeCallPattern,
               ConvertAsyncYieldToGpuRuntimeCallPattern>(converter);
  patterns.add<ConvertLaunchFuncOpToGpuRuntimeCallPattern>(converter,
                                                           gpuBinaryAnnotation);
  patterns.add<EraseGpuModuleOpPattern>(&converter.getContext());
}