1 //===- NVGPUDialect.cpp - MLIR NVGPU ops implementation -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the NVGPU dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/NVGPU/NVGPUDialect.h"
14 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
15 #include "mlir/IR/Builders.h"
16 #include "mlir/IR/DialectImplementation.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/TypeUtilities.h"
19 #include "llvm/ADT/TypeSwitch.h"
20 
21 using namespace mlir;
22 using namespace mlir::nvgpu;
23 
24 #include "mlir/Dialect/NVGPU/NVGPUDialect.cpp.inc"
25 
26 void nvgpu::NVGPUDialect::initialize() {
27   addTypes<DeviceAsyncTokenType>();
28   addOperations<
29 #define GET_OP_LIST
30 #include "mlir/Dialect/NVGPU/NVGPU.cpp.inc"
31       >();
32 }
33 
34 Type NVGPUDialect::parseType(DialectAsmParser &parser) const {
35   // Parse the main keyword for the type.
36   StringRef keyword;
37   if (parser.parseKeyword(&keyword))
38     return Type();
39   MLIRContext *context = getContext();
40   // Handle 'device async token' types.
41   if (keyword == "device.async.token")
42     return DeviceAsyncTokenType::get(context);
43 
44   parser.emitError(parser.getNameLoc(), "unknown nvgpu type: " + keyword);
45   return Type();
46 }
47 
48 void NVGPUDialect::printType(Type type, DialectAsmPrinter &os) const {
49   TypeSwitch<Type>(type)
50       .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; })
51       .Default([](Type) { llvm_unreachable("unexpected 'nvgpu' type kind"); });
52 }
53 //===----------------------------------------------------------------------===//
54 // NVGPU_DeviceAsyncCopyOp
55 //===----------------------------------------------------------------------===//
56 
57 /// Return true if the last dimension of the MemRefType has unit stride. Also
58 /// return true for memrefs with no strides.
59 static bool isLastMemrefDimUnitStride(MemRefType type) {
60   int64_t offset;
61   SmallVector<int64_t> strides;
62   if (failed(getStridesAndOffset(type, strides, offset))) {
63     return false;
64   }
65   return strides.back() == 1;
66 }
67 
68 LogicalResult DeviceAsyncCopyOp::verify() {
69   auto srcMemref = src().getType().cast<MemRefType>();
70   auto dstMemref = dst().getType().cast<MemRefType>();
71   unsigned workgroupAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
72   if (!isLastMemrefDimUnitStride(srcMemref))
73     return emitError("source memref most minor dim must have unit stride");
74   if (!isLastMemrefDimUnitStride(dstMemref))
75     return emitError("destination memref most minor dim must have unit stride");
76   if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
77     return emitError("destination memref must have memory space ")
78            << workgroupAddressSpace;
79   if (dstMemref.getElementType() != srcMemref.getElementType())
80     return emitError("source and destination must have the same element type");
81   if (size_t(srcMemref.getRank()) != srcIndices().size())
82     return emitOpError() << "expected " << srcMemref.getRank()
83                          << " source indices, got " << srcIndices().size();
84   if (size_t(dstMemref.getRank()) != dstIndices().size())
85     return emitOpError() << "expected " << dstMemref.getRank()
86                          << " destination indices, got " << dstIndices().size();
87   return success();
88 }
89 
90 #define GET_OP_CLASSES
91 #include "mlir/Dialect/NVGPU/NVGPU.cpp.inc"
92