//===- NVGPUDialect.cpp - MLIR NVGPU ops implementation ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the NVGPU dialect and its operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/NVGPUDialect.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::nvgpu;

#include "mlir/Dialect/NVGPU/NVGPUDialect.cpp.inc"

void nvgpu::NVGPUDialect::initialize() {
  addTypes<DeviceAsyncTokenType>();
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/NVGPU.cpp.inc"
      >();
}

Type NVGPUDialect::parseType(DialectAsmParser &parser) const {
  // Parse the main keyword for the type.
  StringRef keyword;
  if (parser.parseKeyword(&keyword))
    return Type();
  MLIRContext *context = getContext();

  // Handle 'device async token' types.
  if (keyword == "device.async.token")
    return DeviceAsyncTokenType::get(context);

  parser.emitError(parser.getNameLoc(), "unknown nvgpu type: " + keyword);
  return Type();
}

void NVGPUDialect::printType(Type type, DialectAsmPrinter &os) const {
  TypeSwitch<Type>(type)
      .Case<DeviceAsyncTokenType>([&](Type) { os << "device.async.token"; })
      .Default([](Type) { llvm_unreachable("unexpected 'nvgpu' type kind"); });
}

//===----------------------------------------------------------------------===//
// NVGPU_DeviceAsyncCopyOp
//===----------------------------------------------------------------------===//

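// For reference, a minimal sketch of the IR the verifier below is meant to
// accept. The op mnemonic and assembly format are defined in the ODS file
// (NVGPU.td), not here, so the exact spelling is an assumption used purely
// for illustration:
//
//   %token = nvgpu.device_async_copy %src[%i, %j], %dst[%k, %l, %m], 4
//       : memref<128x128xf32> to memref<3x16x128xf32, 3>
//
// That is: the destination memref lives in the workgroup (shared) memory
// space, both memrefs share an element type, the innermost dimension of each
// memref has unit stride, and the number of indices matches each memref's
// rank.
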
/// Return true if the last dimension of the MemRefType has unit stride. Also
/// return true for memrefs with no strides.
static bool isLastMemrefDimUnitStride(MemRefType type) {
  int64_t offset;
  SmallVector<int64_t> strides;
  if (failed(getStridesAndOffset(type, strides, offset)))
    return false;
  // Treat memrefs with no strides (e.g. rank-0) as having unit stride, as
  // documented above.
  return strides.empty() || strides.back() == 1;
}

LogicalResult DeviceAsyncCopyOp::verify() {
  auto srcMemref = src().getType().cast<MemRefType>();
  auto dstMemref = dst().getType().cast<MemRefType>();
  unsigned workgroupAddressSpace = gpu::GPUDialect::getWorkgroupAddressSpace();
  if (!isLastMemrefDimUnitStride(srcMemref))
    return emitError("source memref most minor dim must have unit stride");
  if (!isLastMemrefDimUnitStride(dstMemref))
    return emitError("destination memref most minor dim must have unit stride");
  if (dstMemref.getMemorySpaceAsInt() != workgroupAddressSpace)
    return emitError("destination memref must have memory space ")
           << workgroupAddressSpace;
  if (dstMemref.getElementType() != srcMemref.getElementType())
    return emitError("source and destination must have the same element type");
  if (size_t(srcMemref.getRank()) != srcIndices().size())
    return emitOpError() << "expected " << srcMemref.getRank()
                         << " source indices, got " << srcIndices().size();
  if (size_t(dstMemref.getRank()) != dstIndices().size())
    return emitOpError() << "expected " << dstMemref.getRank()
                         << " destination indices, got " << dstIndices().size();
  return success();
}

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/NVGPU.cpp.inc"