1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the types and operation details for the NVVM IR dialect in
10 // MLIR, and the LLVM IR dialect. It also registers the dialect.
11 //
12 // The NVVM dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
18
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/DialectImplementation.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/Operation.h"
24 #include "mlir/IR/OperationSupport.h"
25 #include "llvm/ADT/TypeSwitch.h"
26 #include "llvm/AsmParser/Parser.h"
27 #include "llvm/IR/Attributes.h"
28 #include "llvm/IR/Function.h"
29 #include "llvm/IR/Type.h"
30 #include "llvm/Support/SourceMgr.h"
31
32 using namespace mlir;
33 using namespace NVVM;
34
35 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc"
36 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc"
37
38 //===----------------------------------------------------------------------===//
39 // Printing/parsing for NVVM ops
40 //===----------------------------------------------------------------------===//
41
printNVVMIntrinsicOp(OpAsmPrinter & p,Operation * op)42 static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) {
43 p << " " << op->getOperands();
44 if (op->getNumResults() > 0)
45 p << " : " << op->getResultTypes();
46 }
47
48 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type
parse(OpAsmParser & parser,OperationState & result)49 ParseResult VoteBallotOp::parse(OpAsmParser &parser, OperationState &result) {
50 MLIRContext *context = parser.getContext();
51 auto int32Ty = IntegerType::get(context, 32);
52 auto int1Ty = IntegerType::get(context, 1);
53
54 SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
55 Type type;
56 return failure(parser.parseOperandList(ops) ||
57 parser.parseOptionalAttrDict(result.attributes) ||
58 parser.parseColonType(type) ||
59 parser.addTypeToList(type, result.types) ||
60 parser.resolveOperands(ops, {int32Ty, int1Ty},
61 parser.getNameLoc(), result.operands));
62 }
63
print(OpAsmPrinter & p)64 void VoteBallotOp::print(OpAsmPrinter &p) { printNVVMIntrinsicOp(p, *this); }
65
verify()66 LogicalResult CpAsyncOp::verify() {
67 if (getSize() != 4 && getSize() != 8 && getSize() != 16)
68 return emitError("expected byte size to be either 4, 8 or 16.");
69 if (getBypassL1() && getSize() != 16)
70 return emitError("bypass l1 is only support for 16 bytes copy.");
71 return success();
72 }
73
74 // Given the element type of an operand and whether or not it is an accumulator,
75 // this function returns the PTX type (`NVVM::MMATypes`) that corresponds to the
76 // operand's element type.
inferOperandMMAType(Type operandElType,bool isAccumulator)77 Optional<mlir::NVVM::MMATypes> MmaOp::inferOperandMMAType(Type operandElType,
78 bool isAccumulator) {
79 auto half2Type =
80 LLVM::getFixedVectorType(Float16Type::get(operandElType.getContext()), 2);
81 if (operandElType.isF64())
82 return NVVM::MMATypes::f64;
83 if (operandElType.isF16() || operandElType == half2Type)
84 return NVVM::MMATypes::f16;
85 if (operandElType.isF32() && isAccumulator)
86 return NVVM::MMATypes::f32;
87 if (operandElType.isF32() && !isAccumulator)
88 return NVVM::MMATypes::tf32;
89 if (operandElType.isa<IntegerType>()) {
90 if (isAccumulator)
91 return NVVM::MMATypes::s32;
92 return llvm::None;
93 }
94
95 if (auto structType = operandElType.dyn_cast<LLVM::LLVMStructType>()) {
96 if (structType.getBody().empty())
97 return llvm::None;
98 return inferOperandMMAType(structType.getBody()[0], isAccumulator);
99 }
100
101 return llvm::None;
102 }
103
isInt4PtxType(MMATypes type)104 static bool isInt4PtxType(MMATypes type) {
105 return (type == MMATypes::u4 || type == MMATypes::s4);
106 }
107
isInt8PtxType(MMATypes type)108 static bool isInt8PtxType(MMATypes type) {
109 return (type == MMATypes::u8 || type == MMATypes::s8);
110 }
111
isIntegerPtxType(MMATypes type)112 static bool isIntegerPtxType(MMATypes type) {
113 return isInt4PtxType(type) || isInt8PtxType(type) || type == MMATypes::b1 ||
114 type == MMATypes::s32;
115 }
116
accumPtxType()117 MMATypes MmaOp::accumPtxType() {
118 Optional<mlir::NVVM::MMATypes> val = inferOperandMMAType(
119 getODSOperands(2).getTypes().front(), /*isAccum=*/true);
120 assert(val.has_value() && "accumulator PTX type should always be inferrable");
121 return val.value();
122 }
123
resultPtxType()124 MMATypes MmaOp::resultPtxType() {
125 Optional<mlir::NVVM::MMATypes> val =
126 inferOperandMMAType(getResult().getType(), /*isAccum=*/true);
127 assert(val.has_value() && "result PTX type should always be inferrable");
128 return val.value();
129 }
130
/// Prints the custom assembly form of nvvm.mma.sync:
///   A[regs] B[regs] C[regs] attr-dict : (typeA, typeB, typeC) -> resType
/// PTX-type attributes that can be re-inferred from the register types are
/// elided from the printed attribute dictionary.
void MmaOp::print(OpAsmPrinter &p) {
  SmallVector<Type, 4> regTypes;
  // One entry per operand segment: its printed name, the name of the
  // attribute holding its PTX element type, and its register values.
  struct OperandFragment {
    StringRef operandName;
    StringRef ptxTypeAttr;
    SmallVector<Value, 4> regs;
    explicit OperandFragment(StringRef name, StringRef ptxTypeName)
        : operandName(name), ptxTypeAttr(ptxTypeName) {}
  };

  // Segment C has no PTX-type attribute; its type is always inferrable.
  std::array<OperandFragment, 3> frags{
      OperandFragment("A", getMultiplicandAPtxTypeAttrName()),
      OperandFragment("B", getMultiplicandBPtxTypeAttrName()),
      OperandFragment("C", "")};
  // Attributes never printed in the dict; inferrable PTX types are appended.
  SmallVector<StringRef, 4> ignoreAttrNames{
      mlir::NVVM::MmaOp::getOperandSegmentSizeAttr()};

  for (unsigned fragIdx = 0; fragIdx < frags.size(); fragIdx++) {
    auto &frag = frags[fragIdx];
    // (start index, length) of this segment in the flat operand list.
    auto varOperandSpec = getODSOperandIndexAndLength(fragIdx);
    for (auto operandIdx = varOperandSpec.first;
         operandIdx < varOperandSpec.first + varOperandSpec.second;
         operandIdx++) {
      frag.regs.push_back(this->getOperand(operandIdx));
      // NOTE(review): only flat operand index 0 (the first A register) ever
      // pushes into regTypes, so `regTypes.back()` below sees the same type
      // for every fragment. This looks like it was meant to be
      // `operandIdx == varOperandSpec.first` — confirm intended behavior.
      if (operandIdx == 0) {
        regTypes.push_back(this->getOperand(operandIdx).getType());
      }
    }
    // If the PTX type is inferrable from the register type, elide its
    // attribute from the printed dict (the parser re-infers it).
    Optional<MMATypes> inferredType =
        inferOperandMMAType(regTypes.back(), /*isAccum=*/fragIdx >= 2);
    if (inferredType)
      ignoreAttrNames.push_back(frag.ptxTypeAttr);
  }

  // Prints one segment as `<name>[%reg, ...] `.
  auto printMmaOperand = [&](const OperandFragment &frag) -> void {
    p << " " << frag.operandName;
    p << "[";
    p.printOperands(frag.regs);
    p << "] ";
  };

  for (const auto &frag : frags) {
    printMmaOperand(frag);
  }

  p.printOptionalAttrDict(this->getOperation()->getAttrs(), ignoreAttrNames);

  // Print the types of the operands and result.
  p << " : "
    << "(";
  llvm::interleaveComma(SmallVector<Type, 3>{frags[0].regs[0].getType(),
                                             frags[1].regs[0].getType(),
                                             frags[2].regs[0].getType()},
                        p);
  p << ")";
  p.printArrowTypeList(TypeRange{this->getRes().getType()});
}
188
/// Builds an nvvm.mma.sync op.
///
/// \param resultType           type of the result struct.
/// \param operandA/B/C         register values for each operand segment.
/// \param shape                the (m, n, k) MMA tile shape; must have size 3.
/// \param b1Op                 optional binary-MMA operation (for b1 types).
/// \param intOverflow          optional accumulator overflow behavior
///                             (for int4/int8 variants).
/// \param multiplicandPtxTypes optional explicit PTX types for A and B; when
///                             absent they are inferred from the register
///                             types (inference may fail for integer types,
///                             in which case the attribute is omitted).
/// \param multiplicandLayouts  optional layouts for A and B; defaults to
///                             row-major A and col-major B.
void MmaOp::build(OpBuilder &builder, OperationState &result, Type resultType,
                  ValueRange operandA, ValueRange operandB, ValueRange operandC,
                  ArrayRef<int64_t> shape, Optional<MMAB1Op> b1Op,
                  Optional<MMAIntOverflow> intOverflow,
                  Optional<std::array<MMATypes, 2>> multiplicandPtxTypes,
                  Optional<std::array<MMALayout, 2>> multiplicandLayouts) {

  assert(shape.size() == 3 && "expected shape to have size 3 (m, n, k)");
  MLIRContext *ctx = builder.getContext();
  result.addAttribute(
      "shape", builder.getAttr<MMAShapeAttr>(shape[0], shape[1], shape[2]));

  result.addOperands(operandA);
  result.addOperands(operandB);
  result.addOperands(operandC);

  // Use the caller-supplied PTX types if given, otherwise infer them from
  // the first register of each multiplicand segment.
  if (multiplicandPtxTypes) {
    result.addAttribute("multiplicandAPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[0]));
    result.addAttribute("multiplicandBPtxType",
                        MMATypesAttr::get(ctx, (*multiplicandPtxTypes)[1]));
  } else {
    if (auto res = inferOperandMMAType(operandA[0].getType(), false))
      result.addAttribute("multiplicandAPtxType", MMATypesAttr::get(ctx, *res));
    if (auto res = inferOperandMMAType(operandB[0].getType(), false))
      result.addAttribute("multiplicandBPtxType", MMATypesAttr::get(ctx, *res));
  }

  if (multiplicandLayouts) {
    result.addAttribute("layoutA",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[0]));
    result.addAttribute("layoutB",
                        MMALayoutAttr::get(ctx, (*multiplicandLayouts)[1]));
  } else {
    // Default layouts: row-major A, col-major B.
    result.addAttribute("layoutA", MMALayoutAttr::get(ctx, MMALayout::row));
    result.addAttribute("layoutB", MMALayoutAttr::get(ctx, MMALayout::col));
  }

  if (intOverflow.has_value())
    result.addAttribute("intOverflowBehavior",
                        MMAIntOverflowAttr::get(ctx, *intOverflow));
  if (b1Op.has_value())
    result.addAttribute("b1Op", MMAB1OpAttr::get(ctx, *b1Op));

  result.addTypes(resultType);
  // Record the sizes of the three variadic operand segments.
  result.addAttribute(
      MmaOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({static_cast<int32_t>(operandA.size()),
                                static_cast<int32_t>(operandB.size()),
                                static_cast<int32_t>(operandC.size())}));
}
240
241 // <operation> :=
242 // A `[` $operandA `]` B `[` $operandB `]` C `[` $operandC `]`
243 // attr-dict : (type($operandA[0]), type($operandB[0]), type($operandC[0]))
244 // `->` type($res)
parse(OpAsmParser & parser,OperationState & result)245 ParseResult MmaOp::parse(OpAsmParser &parser, OperationState &result) {
246 struct OperandFragment {
247 Optional<MMATypes> elemtype;
248 SmallVector<OpAsmParser::UnresolvedOperand, 4> regs;
249 SmallVector<Type> regTypes;
250 };
251
252 Builder &builder = parser.getBuilder();
253 std::array<OperandFragment, 4> frags;
254
255 NamedAttrList namedAttributes;
256
257 // A helper to parse the operand segments.
258 auto parseMmaOperand = [&](StringRef operandName,
259 OperandFragment &frag) -> LogicalResult {
260 if (parser.parseKeyword(operandName).failed())
261 return failure();
262 if (parser
263 .parseOperandList(frag.regs, OpAsmParser::Delimiter::OptionalSquare)
264 .failed())
265 return failure();
266 return success();
267 };
268
269 // Parse the operand segments.
270 if (parseMmaOperand("A", frags[0]).failed())
271 return failure();
272 if (parseMmaOperand("B", frags[1]).failed())
273 return failure();
274 if (parseMmaOperand("C", frags[2]).failed())
275 return failure();
276
277 if (parser.parseOptionalAttrDict(namedAttributes).failed())
278 return failure();
279
280 // Parse the type specification and resolve operands.
281 SmallVector<Type, 3> operandTypes;
282 if (failed(parser.parseColon()))
283 return failure();
284 if (failed(parser.parseLParen()))
285 return failure();
286 if (failed(parser.parseTypeList(operandTypes)))
287 return failure();
288 if (failed(parser.parseRParen()))
289 if (operandTypes.size() != 3)
290 return parser.emitError(
291 parser.getNameLoc(),
292 "expected one type for each operand segment but got " +
293 Twine(operandTypes.size()) + " types");
294 for (const auto &iter : llvm::enumerate(operandTypes)) {
295 auto &frag = frags[iter.index()];
296 frag.regTypes.resize(frag.regs.size(), iter.value());
297 if (failed(parser.resolveOperands(frag.regs, frag.regTypes,
298 parser.getNameLoc(), result.operands)))
299 return failure();
300 frag.elemtype =
301 inferOperandMMAType(frag.regTypes[0], /*isAccum=*/iter.index() < 2);
302 }
303
304 Type resultType;
305 if (parser.parseArrow() || parser.parseType(resultType))
306 return failure();
307 frags[3].elemtype = inferOperandMMAType(resultType, /*isAccum=*/true);
308
309 std::array<StringRef, 2> names{"multiplicandAPtxType",
310 "multiplicandBPtxType"};
311 for (unsigned idx = 0; idx < names.size(); idx++) {
312 const auto &frag = frags[idx];
313 Optional<NamedAttribute> attr = namedAttributes.getNamed(names[idx]);
314 if (!frag.elemtype.has_value() && !attr.has_value()) {
315 return parser.emitError(
316 parser.getNameLoc(),
317 "attribute " + names[idx] +
318 " is not provided explicitly and cannot be inferred");
319 }
320 if (!attr.has_value())
321 result.addAttribute(
322 names[idx], MMATypesAttr::get(parser.getContext(), *frag.elemtype));
323 }
324
325 result.addTypes(resultType);
326 if (!namedAttributes.empty())
327 result.addAttributes(namedAttributes);
328 result.addAttribute(MmaOp::getOperandSegmentSizeAttr(),
329 builder.getI32VectorAttr({
330 static_cast<int32_t>(frags[0].regs.size()),
331 static_cast<int32_t>(frags[1].regs.size()),
332 static_cast<int32_t>(frags[2].regs.size()),
333 }));
334 return success();
335 }
336
/// Verifier for nvvm.mma.sync: checks that the (m, n, k) shape, the PTX
/// element types, and the operand/result fragment types form one of the
/// supported PTX mma.sync variants, and that the required modifier
/// attributes (b1Op, intOverflowBehavior) are present where mandatory.
LogicalResult MmaOp::verify() {
  MLIRContext *context = getContext();
  // Commonly used scalar/vector/struct types for the allowed-type tables.
  auto f16Ty = Float16Type::get(context);
  auto i32Ty = IntegerType::get(context, 32);
  auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
  auto f32Ty = Float32Type::get(context);
  auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral(
      context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty});

  auto s32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty, i32Ty, i32Ty});
  auto f32x8StructTy =
      LLVM::LLVMStructType::getLiteral(context, SmallVector<Type>(8, f32Ty));
  auto f16x2x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f16x2Ty, f16x2Ty});
  auto f32x4StructTy =
      LLVM::LLVMStructType::getLiteral(context, {f32Ty, f32Ty, f32Ty, f32Ty});
  auto s32x2StructTy =
      LLVM::LLVMStructType::getLiteral(context, {i32Ty, i32Ty});

  std::array<int64_t, 3> mmaShape{getShapeAttr().getM(), getShapeAttr().getN(),
                                  getShapeAttr().getK()};

  // These variables define the set of allowed data types for matrices A, B, C,
  // and result.
  using AllowedShapes = SmallVector<std::array<int64_t, 3>, 2>;
  using AllowedTypes = SmallVector<SmallVector<Type, 4>, 2>;
  AllowedShapes allowedShapes;
  AllowedTypes expectedA;
  AllowedTypes expectedB;
  AllowedTypes expectedC;
  SmallVector<Type> expectedResult;

  // When M = 16, we just need to calculate the number of 8xk tiles, where
  // k is a factor that depends on the data type.
  if (mmaShape[0] == 16) {
    int64_t kFactor;
    Type multiplicandFragType;
    switch (*getMultiplicandAPtxType()) {
    case MMATypes::tf32:
      // tf32 fragments are carried in i32 registers.
      kFactor = 4;
      multiplicandFragType = i32Ty;
      expectedResult.push_back(LLVM::LLVMStructType::getLiteral(
          context, {f32Ty, f32Ty, f32Ty, f32Ty}));
      break;
    case MMATypes::f16:
    case MMATypes::bf16:
      kFactor = 8;
      multiplicandFragType = f16x2Ty;
      expectedResult.push_back(f16x2x2StructTy);
      expectedResult.push_back(f32x4StructTy);
      break;
    case MMATypes::s4:
    case MMATypes::u4:
      kFactor = 32;
      break;
    case MMATypes::b1:
      kFactor = 128;
      break;
    case MMATypes::s8:
    case MMATypes::u8:
      kFactor = 16;
      break;
    default:
      return emitError("invalid shape or multiplicand type: " +
                       stringifyEnum(getMultiplicandAPtxType().value()));
    }

    // Integer variants: i32 fragments and s32x4 accumulator/result.
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedResult.push_back(s32x4StructTy);
      expectedC.emplace_back(4, i32Ty);
      multiplicandFragType = i32Ty;
    } else {
      expectedC.emplace_back(2, f16x2Ty);
      expectedC.emplace_back(4, f32Ty);
    }

    // Number of fragment registers per multiplicand = number of 8xk tiles.
    int64_t unitA = (mmaShape[0] / 8) * (mmaShape[2] / kFactor);
    int64_t unitB = (mmaShape[1] / 8) * (mmaShape[2] / kFactor);
    expectedA.emplace_back(unitA, multiplicandFragType);
    expectedB.emplace_back(unitB, multiplicandFragType);
    allowedShapes.push_back({16, 8, kFactor});
    allowedShapes.push_back({16, 8, kFactor * 2});
  }

  // In the M=8 case, there is only 1 possible case per data type.
  if (mmaShape[0] == 8) {
    if (*getMultiplicandAPtxType() == MMATypes::f16) {
      expectedA.emplace_back(2, f16x2Ty);
      expectedB.emplace_back(2, f16x2Ty);
      expectedResult.push_back(f16x2x4StructTy);
      expectedResult.push_back(f32x8StructTy);
      expectedC.emplace_back(4, f16x2Ty);
      expectedC.emplace_back(8, f32Ty);
      allowedShapes.push_back({8, 8, 4});
    }
    if (*getMultiplicandAPtxType() == MMATypes::f64) {
      Type f64Ty = Float64Type::get(context);
      expectedA.emplace_back(1, f64Ty);
      expectedB.emplace_back(1, f64Ty);
      expectedC.emplace_back(2, f64Ty);
      // expectedC.emplace_back(1, LLVM::getFixedVectorType(f64Ty, 2));
      expectedResult.emplace_back(LLVM::LLVMStructType::getLiteral(
          context, SmallVector<Type>(2, f64Ty)));
      allowedShapes.push_back({8, 8, 4});
    }
    if (isIntegerPtxType(getMultiplicandAPtxType().value())) {
      expectedA.push_back({i32Ty});
      expectedB.push_back({i32Ty});
      expectedC.push_back({i32Ty, i32Ty});
      expectedResult.push_back(s32x2StructTy);
      if (isInt4PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 32});
      if (isInt8PtxType(getMultiplicandAPtxType().value()))
        allowedShapes.push_back({8, 8, 16});
      if (getMultiplicandAPtxType().value() == MMATypes::b1)
        allowedShapes.push_back({8, 8, 128});
    }
  }

  std::string errorMessage;
  llvm::raw_string_ostream errorStream(errorMessage);

  // Check that we matched an existing shape/dtype combination.
  if (expectedA.empty() || expectedB.empty() || expectedC.empty() ||
      !llvm::any_of(allowedShapes,
                    [&](const auto &allowed) { return allowed == mmaShape; })) {
    errorStream << "unimplemented variant for MMA shape <";
    llvm::interleaveComma(mmaShape, errorStream);
    errorStream << ">";
    return emitOpError(errorMessage);
  }

  // Verify the operand types for segments of A, B, and C operands.
  std::array<StringRef, 3> operandNames{"A", "B", "C"};
  for (const auto &iter : llvm::enumerate(
           SmallVector<AllowedTypes, 3>{expectedA, expectedB, expectedC})) {
    // Slice this segment's operand types out of the flat operand list.
    auto spec = this->getODSOperandIndexAndLength(iter.index());
    SmallVector<Type, 4> operandTySeg(operand_type_begin() + spec.first,
                                      operand_type_begin() + spec.first +
                                          spec.second);
    // The segment matches if it equals any allowed type vector exactly.
    bool match =
        llvm::any_of(iter.value(), [&](const SmallVector<Type, 4> &typeSet) {
          return typeSet == operandTySeg;
        });

    if (!match) {
      errorStream << "Could not match types for the "
                  << operandNames[iter.index()]
                  << " operands; expected one of ";
      for (const auto &x : iter.value()) {
        errorStream << x.size() << "x" << x[0] << " ";
      }
      errorStream << "but got ";
      llvm::interleaveComma(operandTySeg, errorStream);
      return emitOpError(errorStream.str());
    }
  }

  // Check the result type
  if (!llvm::any_of(expectedResult, [&](Type expectedResultType) {
        return expectedResultType == getResult().getType();
      })) {
    errorStream
        << "Could not match allowed types for the result; expected one of ";
    llvm::interleaveComma(expectedResult, errorStream);
    errorStream << " but got " << getResult().getType();
    return emitOpError(errorStream.str());
  }

  // Ensure that binary MMA variants have a b1 MMA operation defined.
  if (getMultiplicandAPtxType() == MMATypes::b1 && !getB1Op()) {
    return emitOpError("op requires " + getB1OpAttrName().strref() +
                       " attribute");
  }

  // Ensure int4/int8 MMA variants specify the accum overflow behavior
  // attribute.
  if (isInt4PtxType(*getMultiplicandAPtxType()) ||
      isInt8PtxType(*getMultiplicandAPtxType())) {
    if (!getIntOverflowBehavior())
      return emitOpError("op requires " +
                         getIntOverflowBehaviorAttrName().strref() +
                         " attribute");
  }

  return success();
}
525
verify()526 LogicalResult ShflOp::verify() {
527 if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
528 return success();
529 auto type = getType().dyn_cast<LLVM::LLVMStructType>();
530 auto elementType = (type && type.getBody().size() == 2)
531 ? type.getBody()[1].dyn_cast<IntegerType>()
532 : nullptr;
533 if (!elementType || elementType.getWidth() != 1)
534 return emitError("expected return type to be a two-element struct with "
535 "i1 as the second element");
536 return success();
537 }
538
inferMMAType(NVVM::MMATypes type,NVVM::MMAFrag frag,MLIRContext * context)539 std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
540 NVVM::MMAFrag frag,
541 MLIRContext *context) {
542 unsigned numberElements = 0;
543 Type elementType;
544 OpBuilder builder(context);
545 Type f16x2 = VectorType::get(2, builder.getF16Type());
546 if (type == NVVM::MMATypes::f16) {
547 elementType = f16x2;
548 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b)
549 numberElements = 8;
550 else
551 numberElements = 4;
552 } else if (type == NVVM::MMATypes::f32) {
553 elementType = builder.getF32Type();
554 numberElements = 8;
555 } else if (type == NVVM::MMATypes::tf32) {
556 elementType = builder.getI32Type();
557 numberElements = 4;
558 }
559 assert(numberElements != 0 && elementType != nullptr);
560 return std::make_pair(elementType, numberElements);
561 }
562
verify()563 LogicalResult NVVM::WMMALoadOp::verify() {
564 unsigned addressSpace =
565 getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
566 if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
567 return emitOpError("expected source pointer in memory "
568 "space 0, 1, 3");
569
570 if (NVVM::WMMALoadOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
571 getEltype(), getFrag()) == 0)
572 return emitOpError() << "invalid attribute combination";
573 std::pair<Type, unsigned> typeInfo =
574 inferMMAType(getEltype(), getFrag(), getContext());
575 Type dstType = LLVM::LLVMStructType::getLiteral(
576 getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
577 if (getType() != dstType)
578 return emitOpError("expected destination type is a structure of ")
579 << typeInfo.second << " elements of type " << typeInfo.first;
580 return success();
581 }
582
verify()583 LogicalResult NVVM::WMMAStoreOp::verify() {
584 unsigned addressSpace =
585 getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
586 if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
587 return emitOpError("expected operands to be a source pointer in memory "
588 "space 0, 1, 3");
589
590 if (NVVM::WMMAStoreOp::getIntrinsicID(getM(), getN(), getK(), getLayout(),
591 getEltype()) == 0)
592 return emitOpError() << "invalid attribute combination";
593 std::pair<Type, unsigned> typeInfo =
594 inferMMAType(getEltype(), NVVM::MMAFrag::c, getContext());
595 if (getArgs().size() != typeInfo.second)
596 return emitOpError() << "expected " << typeInfo.second << " data operands";
597 if (llvm::any_of(getArgs(), [&typeInfo](Value operands) {
598 return operands.getType() != typeInfo.first;
599 }))
600 return emitOpError() << "expected data operands of type " << typeInfo.first;
601 return success();
602 }
603
/// Verifier for nvvm.wmma.mma: checks the attribute combination maps to a
/// real intrinsic and that the argument and result types match the fragment
/// shapes implied by the element types.
LogicalResult NVVM::WMMAMmaOp::verify() {
  // A zero intrinsic ID means no intrinsic exists for this combination.
  if (NVVM::WMMAMmaOp::getIntrinsicID(getM(), getN(), getK(), getLayoutA(),
                                      getLayoutB(), getEltypeA(),
                                      getEltypeB()) == 0)
    return emitOpError() << "invalid attribute combination";
  // NOTE(review): EltypeA is used for both the A and B fragments (the op
  // carries only two eltype attributes); presumably EltypeA is the shared
  // multiplicand type and EltypeB the accumulator type — confirm against
  // the op definition.
  std::pair<Type, unsigned> typeInfoA =
      inferMMAType(getEltypeA(), NVVM::MMAFrag::a, getContext());
  std::pair<Type, unsigned> typeInfoB =
      inferMMAType(getEltypeA(), NVVM::MMAFrag::b, getContext());
  std::pair<Type, unsigned> typeInfoC =
      inferMMAType(getEltypeB(), NVVM::MMAFrag::c, getContext());
  // The expected flat argument list is A's registers, then B's, then C's.
  SmallVector<Type, 32> arguments;
  arguments.append(typeInfoA.second, typeInfoA.first);
  arguments.append(typeInfoB.second, typeInfoB.first);
  arguments.append(typeInfoC.second, typeInfoC.first);
  unsigned numArgs = arguments.size();
  if (getArgs().size() != numArgs)
    return emitOpError() << "expected " << numArgs << " arguments";
  for (unsigned i = 0; i < numArgs; i++) {
    if (getArgs()[i].getType() != arguments[i])
      return emitOpError() << "expected argument " << i << " to be of type "
                           << arguments[i];
  }
  // The result must be a struct with C's fragment shape.
  Type dstType = LLVM::LLVMStructType::getLiteral(
      getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
  if (getType() != dstType)
    return emitOpError("expected destination type is a structure of ")
           << typeInfoC.second << " elements of type " << typeInfoC.first;
  return success();
}
634
verify()635 LogicalResult NVVM::LdMatrixOp::verify() {
636 unsigned addressSpace =
637 getPtr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
638 if (addressSpace != 3)
639 return emitOpError("expected source pointer in memory space 3");
640
641 if (getNum() != 1 && getNum() != 2 && getNum() != 4)
642 return emitOpError("expected num attribute to be 1, 2 or 4");
643
644 Type i32 = IntegerType::get(getContext(), 32);
645 if (getNum() == 1 && getType() != i32)
646 return emitOpError("expected destination type is i32");
647 if (getNum() == 2 || getNum() == 4) {
648 Type dstType = LLVM::LLVMStructType::getLiteral(
649 getContext(), SmallVector<Type>(getNum(), i32));
650 if (getType() != dstType)
651 return emitOpError("expected destination type is a structure of ")
652 << getNum() << " elements of type i32";
653 }
654 return success();
655 }
656
657 //===----------------------------------------------------------------------===//
658 // NVVMDialect initialization, type parsing, and registration.
659 //===----------------------------------------------------------------------===//
660
// TODO: This should be the llvm.nvvm dialect once this is supported.
/// Registers the TableGen-generated NVVM operations and attributes with
/// the dialect.
void NVVMDialect::initialize() {
  // Register all ops generated from NVVMOps.td.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
      >();
  // Register all attributes generated from NVVMOps.td.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
      >();

  // Support unknown operations because not all NVVM operations are
  // registered.
  allowUnknownOperations();
}
676
verifyOperationAttribute(Operation * op,NamedAttribute attr)677 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op,
678 NamedAttribute attr) {
679 // Kernel function attribute should be attached to functions.
680 if (attr.getName() == NVVMDialect::getKernelFuncAttrName()) {
681 if (!isa<LLVM::LLVMFuncOp>(op)) {
682 return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName()
683 << "' attribute attached to unexpected op";
684 }
685 }
686 return success();
687 }
688
689 #define GET_OP_CLASSES
690 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc"
691
692 #define GET_ATTRDEF_CLASSES
693 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc"
694