1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the NVVM IR dialect in
10 // MLIR, and the LLVM IR dialect.  It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18 
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/Operation.h"
24 #include "mlir/IR/OperationSupport.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/AsmParser/Parser.h"
27 #include "llvm/IR/Attributes.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/Type.h"
30 #include "llvm/Support/SourceMgr.h"
31 
32 using namespace mlir;
33 using namespace NVVM;
34 
35 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
36 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
37 
38 //===----------------------------------------------------------------------===//
39 // Printing/parsing for NVVM ops
40 //===----------------------------------------------------------------------===//
41 
printNVVMIntrinsicOp(OpAsmPrinter & p,Operation * op)42 static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
43   p << " " << op->getOperands();
44   if (op->getNumResults() > 0)
45     p << " : " << op->getResultTypes();
46 }
47 
48 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
parse(OpAsmParser & parser,OperationState & result)49 ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
50   MLIRContext *context = parser.getContext();
51   auto int32Ty = IntegerType::get(context, 32);
52   auto int1Ty = IntegerType::get(context, 1);
53 
54   SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
55   Type type;
56   return failure(parser.parseOperandList(ops) ||
57                  parser.parseOptionalAttrDict(result.attributes) ||
58                  parser.parseColonType(type) ||
59                  parser.addTypeToList(type, result.types) ||
60                  parser.resolveOperands(ops, {int32Ty, int1Ty},
61                                         parser.getNameLoc(), result.operands));
62 }
63 
// Prints a VoteBallotOp using the shared generic-intrinsic printer above.
void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
65 
verify()66 LogicalResult CpAsyncOp::verify() {
67   if (getSize() != 4 && getSize() != 8 && getSize() != 16)
68     return emitError("expected byte size to be either 4, 8 or 16.");
69   if (getBypassL1() && getSize() != 16)
70     return emitError("bypass l1 is only support for 16 bytes copy.");
71   return success();
72 }
73 
74 // Given the element type of an operand and whether or not it is an accumulator,
75 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
76 // operand's element type.
inferOperandMMAType(Type operandElType,bool isAccumulator)77 Optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type operandElType,
78                                                           bool isAccumulator) {
79   auto half2Type =
80       LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
81   if (operandElType.isF64())
82     return NVVM::MMATypes::f64;
83   if (operandElType.isF16() || operandElType == half2Type)
84     return NVVM::MMATypes::f16;
85   if (operandElType.isF32() && isAccumulator)
86     return NVVM::MMATypes::f32;
87   if (operandElType.isF32() && !isAccumulator)
88     return NVVM::MMATypes::tf32;
89   if (operandElType.isa<IntegerType>()) {
90     if (isAccumulator)
91       return NVVM::MMATypes::s32;
92     return llvm::None;
93   }
94 
95   if (auto structType = operandElType.dyn_cast<LLVM::LLVMStructType>()) {
96     if (structType.getBody().empty())
97       return llvm::None;
98     return inferOperandMMAType(structType.getBody()[0], isAccumulator);
99   }
100 
101   return llvm::None;
102 }
103 
isInt4PtxType(MMATypes type)104 static bool isInt4PtxType(MMATypes type) {
105   return (type == MMATypes::u4 || type == MMATypes::s4);
106 }
107 
isInt8PtxType(MMATypes type)108 static bool isInt8PtxType(MMATypes type) {
109   return (type == MMATypes::u8 || type == MMATypes::s8);
110 }
111 
isIntegerPtxType(MMATypes type)112 static bool isIntegerPtxType(MMATypes type) {
113   return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
114          type == MMATypes::s32;
115 }
116 
accumPtxType()117 MMATypes MmaOp::accumPtxType() {
118   Optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
119       getODSOperands(2).getTypes().front(), /*isAccum=*/true);
120   assert(val.has_value() && "accumulator PTX type should always be inferrable");
121   return val.value();
122 }
123 
resultPtxType()124 MMATypes MmaOp::resultPtxType() {
125   Optional<mlir::NVVM::MMATypes> val =
126       inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
127   assert(val.has_value() && "result PTX type should always be inferrable");
128   return val.value();
129 }
130 
// Prints an MmaOp in the custom assembly form:
//   A[regs] B[regs] C[regs] attr-dict : (typeA, typeB, typeC) -> resultType
// PTX-type attributes that the parser can re-infer from the printed operand
// types are elided from the attribute dictionary.
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  // Bookkeeping for one operand segment (A, B, or C): its display name, the
  // name of its PTX-type attribute (empty for C), and its register values.
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  // The segment-sizes attribute is implied by the printed form; always elide.
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      // NOTE(review): this condition is only true for the op's very first
      // operand (the first A register), so `regTypes` holds a single entry and
      // `regTypes.back()` below sees that same type for every fragment —
      // presumably intentional when all fragments share an inferable element
      // type, but worth confirming against the parser's inference.
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    // When the PTX type is inferable from the register type, elide the
    // corresponding explicit attribute from the printed attr-dict.
    Optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  // Prints one segment as ` <name>[%reg, ...] `.
  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result. Only the first register type
  // of each segment is printed; the parser applies it to the whole segment.
  p << " : "
    << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}
188 
build(OpBuilder & builder,OperationState & result,Type resultType,ValueRange operandA,ValueRange operandB,ValueRange operandC,ArrayRef<int64_t> shape,Optional<MMAB1Op> b1Op,Optional<MMAIntOverflow> intOverflow,Optional<std::array<MMATypes,2>> multiplicandPtxTypes,Optional<std::array<MMALayout,2>> multiplicandLayouts)189 void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
190                   ValueRange operandA, ValueRange operandB, ValueRange operandC,
191                   ArrayRef<int64_t> shape, Optional<MMAB1Op> b1Op,
192                   Optional<MMAIntOverflow> intOverflow,
193                   Optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
194                   Optional<std::array<MMALayout, 2>> multiplicandLayouts) {
195 
196   assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
197   MLIRContext *ctx = builder.getContext();
198   result.addAttribute(
199       "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));
200 
201   result.addOperands(operandA);
202   result.addOperands(operandB);
203   result.addOperands(operandC);
204 
205   if (multiplicandPtxTypes) {
206     result.addAttribute("multiplicandAPtxType",
207                         MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
208     result.addAttribute("multiplicandBPtxType",
209                         MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
210   } else {
211     if (auto res = inferOperandMMAType(operandA[0].getType(), false))
212       result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
213     if (auto res = inferOperandMMAType(operandB[0].getType(), false))
214       result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
215   }
216 
217   if (multiplicandLayouts) {
218     result.addAttribute("layoutA",
219                         MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
220     result.addAttribute("layoutB",
221                         MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
222   } else {
223     result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
224     result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
225   }
226 
227   if (intOverflow.has_value())
228     result.addAttribute("intOverflowBehavior",
229                         MMAIntOverflowAttr::get(ctx, *intOverflow));
230   if (b1Op.has_value())
231     result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));
232 
233   result.addTypes(resultType);
234   result.addAttribute(
235       MmaOp::getOperandSegmentSizeAttr(),
236       builder.getI32VectorAttr({static_cast<int32_t>(operandA.size()),
237                                 static_cast<int32_t>(operandB.size()),
238                                 static_cast<int32_t>(operandC.size())}));
239 }
240 
241 // <operation> :=
242 //   A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
243 //   attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
244 //     `->` type($res)
parse(OpAsmParser & parser,OperationState & result)245 ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
246   struct OperandFragment {
247     Optional<MMATypes> elemtype;
248     SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
249     SmallVector<Type> regTypes;
250   };
251 
252   Builder &builder = parser.getBuilder();
253   std::array<OperandFragment, 4> frags;
254 
255   NamedAttrList namedAttributes;
256 
257   // A helper to parse the operand segments.
258   auto parseMmaOperand = [&](StringRef operandName,
259                              OperandFragment &frag) -> LogicalResult {
260     if (parser.parseKeyword(operandName).failed())
261       return failure();
262     if (parser
263             .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
264             .failed())
265       return failure();
266     return success();
267   };
268 
269   // Parse the operand segments.
270   if (parseMmaOperand("A", frags[0]).failed())
271     return failure();
272   if (parseMmaOperand("B", frags[1]).failed())
273     return failure();
274   if (parseMmaOperand("C", frags[2]).failed())
275     return failure();
276 
277   if (parser.parseOptionalAttrDict(namedAttributes).failed())
278     return failure();
279 
280   // Parse the type specification and resolve operands.
281   SmallVector<Type, 3> operandTypes;
282   if (failed(parser.parseColon()))
283     return failure();
284   if (failed(parser.parseLParen()))
285     return failure();
286   if (failed(parser.parseTypeList(operandTypes)))
287     return failure();
288   if (failed(parser.parseRParen()))
289     if (operandTypes.size() != 3)
290       return parser.emitError(
291           parser.getNameLoc(),
292           "expected one type for each operand segment but got " +
293               Twine(operandTypes.size()) + " types");
294   for (const auto &iter : llvm::enumerate(operandTypes)) {
295     auto &frag = frags[iter.index()];
296     frag.regTypes.resize(frag.regs.size(), iter.value());
297     if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
298                                       parser.getNameLoc(), result.operands)))
299       return failure();
300     frag.elemtype =
301         inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
302   }
303 
304   Type resultType;
305   if (parser.parseArrow() || parser.parseType(resultType))
306     return failure();
307   frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
308 
309   std::array<StringRef, 2> names{"multiplicandAPtxType",
310                                  "multiplicandBPtxType"};
311   for (unsigned idx = 0; idx < names.size(); idx++) {
312     const auto &frag = frags[idx];
313     Optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
314     if (!frag.elemtype.has_value() && !attr.has_value()) {
315       return parser.emitError(
316           parser.getNameLoc(),
317           "attribute " + names[idx] +
318               " is not provided explicitly and cannot be inferred");
319     }
320     if (!attr.has_value())
321       result.addAttribute(
322           names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
323   }
324 
325   result.addTypes(resultType);
326   if (!namedAttributes.empty())
327     result.addAttributes(namedAttributes);
328   result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
329                       builder.getI32VectorAttr({
330                           static_cast<int32_t>(frags[0].regs.size()),
331                           static_cast<int32_t>(frags[1].regs.size()),
332                           static_cast<int32_t>(frags[2].regs.size()),
333                       }));
334   return success();
335 }
336 
// Verifies an MmaOp: the (m, n, k) shape, the multiplicand PTX type, the
// per-segment operand counts/types, and the result type must together form a
// supported mma.sync variant, and the modifier attributes required by integer
// and b1 variants must be present.
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  // Frequently-used element and aggregate types for the tables below.
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      // tf32 multiplicands are carried in i32 registers.
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      // f16/bf16 variants may accumulate into f16x2 pairs or f32.
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    // Integer variants: fragment/result types are filled in uniformly below.
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      // All integer variants pack their fragments into i32 registers and
      // accumulate into s32.
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    // Number of registers per multiplicand = number of 8 x kFactor tiles.
    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      // The allowed K depends on the integer width.
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::any_of(allowedShapes,
                    [&](const auto &allowed) { return allowed == mmaShape; })) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    // The actual types of this segment's operands.
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    // The segment matches when it equals any one of the allowed type lists.
    bool match =
        llvm::any_of(iter.value(), [&](const SmallVector<Type, 4> &typeSet) {
          return typeSet == operandTySeg;
        });

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorStream.str());
    }
  }

  // Check the result type
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorStream.str());
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}
525 
verify()526 LogicalResult ShflOp::verify() {
527   if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
528     return success();
529   auto type = getType().dyn_cast<LLVM::LLVMStructType>();
530   auto elementType = (type && type.getBody().size() == 2)
531                          ? type.getBody()[1].dyn_cast<IntegerType>()
532                          : nullptr;
533   if (!elementType || elementType.getWidth() != 1)
534     return emitError("expected return type to be a two-element struct with "
535                      "i1 as the second element");
536   return success();
537 }
538 
inferMMAType(NVVM::MMATypes type,NVVM::MMAFrag frag,MLIRContext * context)539 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
540                                                    NVVM::MMAFrag frag,
541                                                    MLIRContext *context) {
542   unsigned numberElements = 0;
543   Type elementType;
544   OpBuilder builder(context);
545   Type f16x2 = VectorType::get(2, builder.getF16Type());
546   if (type == NVVM::MMATypes::f16) {
547     elementType = f16x2;
548     if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
549       numberElements = 8;
550     else
551       numberElements = 4;
552   } else if (type == NVVM::MMATypes::f32) {
553     elementType = builder.getF32Type();
554     numberElements = 8;
555   } else if (type == NVVM::MMATypes::tf32) {
556     elementType = builder.getI32Type();
557     numberElements = 4;
558   }
559   assert(numberElements != 0 && elementType != nullptr);
560   return std::make_pair(elementType, numberElements);
561 }
562 
verify()563 LogicalResult NVVM::WMMALoadOp::verify() {
564   unsigned addressSpace =
565       getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
566   if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
567     return emitOpError("expected source pointer in memory "
568                        "space 0, 1, 3");
569 
570   if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
571                                        getEltype(), getFrag()) == 0)
572     return emitOpError() << "invalid attribute combination";
573   std::pair<Type, unsigned> typeInfo =
574       inferMMAType(getEltype(), getFrag(), getContext());
575   Type dstType = LLVM::LLVMStructType::getLiteral(
576       getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
577   if (getType() != dstType)
578     return emitOpError("expected destination type is a structure of ")
579            << typeInfo.second << " elements of type " << typeInfo.first;
580   return success();
581 }
582 
verify()583 LogicalResult NVVM::WMMAStoreOp::verify() {
584   unsigned addressSpace =
585       getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
586   if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
587     return emitOpError("expected operands to be a source pointer in memory "
588                        "space 0, 1, 3");
589 
590   if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
591                                         getEltype()) == 0)
592     return emitOpError() << "invalid attribute combination";
593   std::pair<Type, unsigned> typeInfo =
594       inferMMAType(getEltype(), NVVM::MMAFrag::c, getContext());
595   if (getArgs().size() != typeInfo.second)
596     return emitOpError() << "expected " << typeInfo.second << " data operands";
597   if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
598         return operands.getType() != typeInfo.first;
599       }))
600     return emitOpError() << "expected data operands of type " << typeInfo.first;
601   return success();
602 }
603 
// Verifies a WMMA mma op: a matching intrinsic must exist, the argument list
// must be the concatenation of the A, B, and C fragments in order, and the
// result must be a struct shaped like the C fragment.
LogicalResult NVVM::WMMAMmaOp::verify() {
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  // NOTE(review): eltypeA is used for both the A and B fragments and eltypeB
  // for the C fragment — presumably eltypeA is the multiplicand element type
  // and eltypeB the accumulator element type; confirm against the op
  // definition if this looks like a typo.
  std::pair<Type, unsigned> typeInfoA =
      inferMMAType(getEltypeA(), NVVM::MMAFrag::a, getContext());
  std::pair<Type, unsigned> typeInfoB =
      inferMMAType(getEltypeA(), NVVM::MMAFrag::b, getContext());
  std::pair<Type, unsigned> typeInfoC =
      inferMMAType(getEltypeB(), NVVM::MMAFrag::c, getContext());
  // Expected operand list: A registers, then B, then C.
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  // The result has the same struct shape as the C (accumulator) fragment.
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}
634 
verify()635 LogicalResult NVVM::LdMatrixOp::verify() {
636   unsigned addressSpace =
637       getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
638   if (addressSpace != 3)
639     return emitOpError("expected source pointer in memory space 3");
640 
641   if (getNum() != 1 && getNum() != 2 && getNum() != 4)
642     return emitOpError("expected num attribute to be 1, 2 or 4");
643 
644   Type i32 = IntegerType::get(getContext(), 32);
645   if (getNum() == 1 && getType() != i32)
646     return emitOpError("expected destination type is i32");
647   if (getNum() == 2 || getNum() == 4) {
648     Type dstType = LLVM::LLVMStructType::getLiteral(
649         getContext(), SmallVector<Type>(getNum(), i32));
650     if (getType() != dstType)
651       return emitOpError("expected destination type is a structure of ")
652              << getNum() << " elements of type i32";
653   }
654   return success();
655 }
656 
657 //===----------------------------------------------------------------------===//
658 // NVVMDialect initialization, type parsing, and registration.
659 //===----------------------------------------------------------------------===//
660 
// TODO: This should be the llvm.nvvm dialect once this is supported.
void NVVMDialect::initialize() {
  // Register all tablegen-generated NVVM operations.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  // Register all tablegen-generated NVVM attributes.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
}
676 
verifyOperationAttribute(Operation * op,NamedAttribute attr)677 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
678                                                     NamedAttribute attr) {
679   // Kernel function attribute should be attached to functions.
680   if (attr.getName() == NVVMDialect::getKernelFuncAttrName()) {
681     if (!isa<LLVM::LLVMFuncOp>(op)) {
682       return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
683                              << "' attribute attached to unexpected op";
684     }
685   }
686   return success();
687 }
688 
689 #define GET_OP_CLASSES
690 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
691 
692 #define GET_ATTRDEF_CLASSES
693 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
694