1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the NVVM IR dialect in
10 // MLIR, and the LLVM IR dialect.  It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/Operation.h"
24 #include "mlir/IR/OperationSupport.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/AsmParser/Parser.h"
27 #include "llvm/IR/Attributes.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/Type.h"
30 #include "llvm/Support/SourceMgr.h"
31 
32 using namespace mlir;
33 using namespace NVVM;
34 
35 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
36 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
37 
38 //===----------------------------------------------------------------------===//
39 // Printing/parsing for NVVM ops
40 //===----------------------------------------------------------------------===//
41 
42 static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
43   p << " " << op->getOperands();
44   if (op->getNumResults() > 0)
45     p << " : " << op->getResultTypes();
46 }
47 
48 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
49 static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
50                                          OperationState &result) {
51   MLIRContext *context = parser.getContext();
52   auto int32Ty = IntegerType::get(context, 32);
53   auto int1Ty = IntegerType::get(context, 1);
54 
55   SmallVector<OpAsmParser::OperandType, 8> ops;
56   Type type;
57   return failure(parser.parseOperandList(ops) ||
58                  parser.parseOptionalAttrDict(result.attributes) ||
59                  parser.parseColonType(type) ||
60                  parser.addTypeToList(type, result.types) ||
61                  parser.resolveOperands(ops, {int32Ty, int1Ty},
62                                         parser.getNameLoc(), result.operands));
63 }
64 
65 static LogicalResult verify(MmaOp op) {
66   MLIRContext *context = op.getContext();
67   auto f16Ty = Float16Type::get(context);
68   auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
69   auto f32Ty = Float32Type::get(context);
70   auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
71       context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});
72   auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
73       context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
74 
75   SmallVector<Type, 12> operandTypes(op.getOperandTypes().begin(),
76                                      op.getOperandTypes().end());
77   if (operandTypes != SmallVector<Type, 8>(8, f16x2Ty) &&
78       operandTypes != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
79                                             f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
80                                             f32Ty, f32Ty, f32Ty}) {
81     return op.emitOpError(
82         "expected operands to be 4 <halfx2>s followed by either "
83         "4 <halfx2>s or 8 floats");
84   }
85   if (op.getType() != f32x8StructTy && op.getType() != f16x2x4StructTy) {
86     return op.emitOpError("expected result type to be a struct of either 4 "
87                           "<halfx2>s or 8 floats");
88   }
89 
90   auto alayout = op->getAttrOfType<StringAttr>("alayout");
91   auto blayout = op->getAttrOfType<StringAttr>("blayout");
92 
93   if (!(alayout && blayout) ||
94       !(alayout.getValue() == "row" || alayout.getValue() == "col") ||
95       !(blayout.getValue() == "row" || blayout.getValue() == "col")) {
96     return op.emitOpError(
97         "alayout and blayout attributes must be set to either "
98         "\"row\" or \"col\"");
99   }
100 
101   if (operandTypes == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
102                                             f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
103                                             f32Ty, f32Ty, f32Ty} &&
104       op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
105       blayout.getValue() == "col") {
106     return success();
107   }
108   return op.emitOpError("unimplemented mma.sync variant");
109 }
110 
111 std::pair<mlir::Type, unsigned>
112 inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, MLIRContext *context) {
113   unsigned numberElements = 0;
114   Type elementType;
115   OpBuilder builder(context);
116   Type f16x2 = VectorType::get(2, builder.getF16Type());
117   if (type == NVVM::MMATypes::f16) {
118     elementType = f16x2;
119     if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
120       numberElements = 8;
121     else
122       numberElements = 4;
123   } else if (type == NVVM::MMATypes::f32) {
124     elementType = builder.getF32Type();
125     numberElements = 8;
126   } else if (type == NVVM::MMATypes::tf32) {
127     elementType = builder.getI32Type();
128     numberElements = 4;
129   }
130   assert(numberElements != 0 && elementType != nullptr);
131   return std::make_pair(elementType, numberElements);
132 }
133 
134 static LogicalResult verify(NVVM::WMMALoadOp op) {
135   unsigned addressSpace =
136       op.ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
137   if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
138     return op.emitOpError("expected source pointer in memory "
139                           "space 0, 1, 3");
140 
141   if (NVVM::WMMALoadOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(),
142                                        op.eltype(), op.frag()) == 0)
143     return op.emitOpError() << "invalid attribute combination";
144   std::pair<Type, unsigned> typeInfo =
145       inferMMAType(op.eltype(), op.frag(), op.getContext());
146   Type dstType = LLVM::LLVMStructType::getLiteral(
147       op.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
148   if (op.getType() != dstType)
149     return op.emitOpError("expected destination type is a structure of ")
150            << typeInfo.second << " elements of type " << typeInfo.first;
151   return success();
152 }
153 
154 static LogicalResult verify(NVVM::WMMAStoreOp op) {
155   unsigned addressSpace =
156       op.ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
157   if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
158     return op.emitOpError("expected operands to be a source pointer in memory "
159                           "space 0, 1, 3");
160 
161   if (NVVM::WMMAStoreOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(),
162                                         op.eltype()) == 0)
163     return op.emitOpError() << "invalid attribute combination";
164   std::pair<Type, unsigned> typeInfo =
165       inferMMAType(op.eltype(), NVVM::MMAFrag::c, op.getContext());
166   if (op.args().size() != typeInfo.second)
167     return op.emitOpError()
168            << "expected " << typeInfo.second << " data operands";
169   if (llvm::any_of(op.args(), [&typeInfo](Value operands) {
170         return operands.getType() != typeInfo.first;
171       }))
172     return op.emitOpError()
173            << "expected data operands of type " << typeInfo.first;
174   return success();
175 }
176 
177 static LogicalResult verify(NVVM::WMMAMmaOp op) {
178   if (NVVM::WMMAMmaOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layoutA(),
179                                       op.layoutB(), op.eltypeA(),
180                                       op.eltypeB()) == 0)
181     return op.emitOpError() << "invalid attribute combination";
182   std::pair<Type, unsigned> typeInfoA =
183       inferMMAType(op.eltypeA(), NVVM::MMAFrag::a, op.getContext());
184   std::pair<Type, unsigned> typeInfoB =
185       inferMMAType(op.eltypeA(), NVVM::MMAFrag::b, op.getContext());
186   std::pair<Type, unsigned> typeInfoC =
187       inferMMAType(op.eltypeB(), NVVM::MMAFrag::c, op.getContext());
188   SmallVector<Type, 32> arguments;
189   arguments.append(typeInfoA.second, typeInfoA.first);
190   arguments.append(typeInfoB.second, typeInfoB.first);
191   arguments.append(typeInfoC.second, typeInfoC.first);
192   unsigned numArgs = arguments.size();
193   if (op.args().size() != numArgs)
194     return op.emitOpError() << "expected " << numArgs << " arguments";
195   for (unsigned i = 0; i < numArgs; i++) {
196     if (op.args()[i].getType() != arguments[i])
197       return op.emitOpError()
198              << "expected argument " << i << " to be of type " << arguments[i];
199   }
200   Type dstType = LLVM::LLVMStructType::getLiteral(
201       op.getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
202   if (op.getType() != dstType)
203     return op.emitOpError("expected destination type is a structure of ")
204            << typeInfoC.second << " elements of type " << typeInfoC.first;
205   return success();
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // NVVMDialect initialization, type parsing, and registration.
210 //===----------------------------------------------------------------------===//
211 
// TODO: This should be the llvm.nvvm dialect once this is supported.
//
// Registers the contents of the dialect: the operations listed by
// NVVMOps.cpp.inc under GET_OP_LIST and the attributes listed by
// NVVMOpsAttributes.cpp.inc under GET_ATTRDEF_LIST. Unknown operations are
// also allowed, since not every NVVM operation has a registered definition.
void NVVMDialect::initialize() {
  // The included file expands to the comma-separated template argument list.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  // Same pattern for the dialect's attribute definitions.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
}
227 
228 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
229                                                     NamedAttribute attr) {
230   // Kernel function attribute should be attached to functions.
231   if (attr.getName() == NVVMDialect::getKernelFuncAttrName()) {
232     if (!isa<LLVM::LLVMFuncOp>(op)) {
233       return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
234                              << "' attribute attached to unexpected op";
235     }
236   }
237   return success();
238 }
239 
240 #define GET_OP_CLASSES
241 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
242 
243 #define GET_ATTRDEF_CLASSES
244 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
245