1 //===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines the types and operation details for the NVVM IR dialect in 10 // MLIR, and the LLVM IR dialect. It also registers the dialect. 11 // 12 // The NVVM dialect only contains GPU specific additions on top of the general 13 // LLVM dialect. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/Dialect/LLVMIR/NVVMDialect.h" 18 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/BuiltinTypes.h" 21 #include "mlir/IR/DialectImplementation.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "mlir/IR/Operation.h" 24 #include "mlir/IR/OperationSupport.h" 25 #include "llvm/ADT/TypeSwitch.h" 26 #include "llvm/AsmParser/Parser.h" 27 #include "llvm/IR/Attributes.h" 28 #include "llvm/IR/Function.h" 29 #include "llvm/IR/Type.h" 30 #include "llvm/Support/SourceMgr.h" 31 32 using namespace mlir; 33 using namespace NVVM; 34 35 #include "mlir/Dialect/LLVMIR/NVVMOpsDialect.cpp.inc" 36 #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.cpp.inc" 37 38 //===----------------------------------------------------------------------===// 39 // Printing/parsing for NVVM ops 40 //===----------------------------------------------------------------------===// 41 42 static void printNVVMIntrinsicOp(OpAsmPrinter &p, Operation *op) { 43 p << " " << op->getOperands(); 44 if (op->getNumResults() > 0) 45 p << " : " << op->getResultTypes(); 46 } 47 48 // <operation> ::= `llvm.nvvm.vote.ballot.sync %mask, %pred` : result_type 49 static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser, 50 OperationState &result) { 51 MLIRContext *context = parser.getContext(); 52 auto int32Ty = IntegerType::get(context, 32); 53 auto int1Ty = IntegerType::get(context, 1); 54 55 SmallVector<OpAsmParser::OperandType, 8> ops; 56 Type type; 57 return failure(parser.parseOperandList(ops) || 58 parser.parseOptionalAttrDict(result.attributes) || 59 parser.parseColonType(type) || 60 parser.addTypeToList(type, result.types) || 61 parser.resolveOperands(ops, {int32Ty, int1Ty}, 62 parser.getNameLoc(), result.operands)); 63 } 64 65 static LogicalResult verify(MmaOp op) { 66 MLIRContext *context = op.getContext(); 67 auto f16Ty = Float16Type::get(context); 68 auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2); 69 auto f32Ty = Float32Type::get(context); 70 auto f16x2x4StructTy = LLVM::LLVMStructType::getLiteral( 71 context, {f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty}); 72 auto f32x8StructTy = LLVM::LLVMStructType::getLiteral( 73 context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty}); 74 75 SmallVector<Type, 12> operandTypes(op.getOperandTypes().begin(), 76 op.getOperandTypes().end()); 77 if (operandTypes != SmallVector<Type, 8>(8, f16x2Ty) && 78 operandTypes != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, 79 f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, 80 f32Ty, f32Ty, f32Ty}) { 81 return op.emitOpError( 82 "expected operands to be 4 <halfx2>s followed by either " 83 "4 <halfx2>s or 8 floats"); 84 } 85 if (op.getType() != f32x8StructTy && op.getType() != f16x2x4StructTy) { 86 return op.emitOpError("expected result type to be a struct of either 4 " 87 "<halfx2>s or 8 floats"); 88 } 89 90 auto alayout = op->getAttrOfType<StringAttr>("alayout"); 91 auto blayout = op->getAttrOfType<StringAttr>("blayout"); 92 93 if (!(alayout && blayout) || 94 !(alayout.getValue() == "row" || alayout.getValue() == "col") || 95 !(blayout.getValue() == "row" || blayout.getValue() == "col")) { 96 return op.emitOpError( 97 "alayout and blayout attributes must be set to either " 98 "\"row\" or \"col\""); 99 } 100 101 if (operandTypes == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, 102 f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, 103 f32Ty, f32Ty, f32Ty} && 104 op.getType() == f32x8StructTy && alayout.getValue() == "row" && 105 blayout.getValue() == "col") { 106 return success(); 107 } 108 return op.emitOpError("unimplemented mma.sync variant"); 109 } 110 111 std::pair<mlir::Type, unsigned> 112 inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, MLIRContext *context) { 113 unsigned numberElements = 0; 114 Type elementType; 115 OpBuilder builder(context); 116 Type f16x2 = VectorType::get(2, builder.getF16Type()); 117 if (type == NVVM::MMATypes::f16) { 118 elementType = f16x2; 119 if (frag == NVVM::MMAFrag::a || frag == NVVM::MMAFrag::b) 120 numberElements = 8; 121 else 122 numberElements = 4; 123 } else if (type == NVVM::MMATypes::f32) { 124 elementType = builder.getF32Type(); 125 numberElements = 8; 126 } else if (type == NVVM::MMATypes::tf32) { 127 elementType = builder.getI32Type(); 128 numberElements = 4; 129 } 130 assert(numberElements != 0 && elementType != nullptr); 131 return std::make_pair(elementType, numberElements); 132 } 133 134 static LogicalResult verify(NVVM::WMMALoadOp op) { 135 unsigned addressSpace = 136 op.ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace(); 137 if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3) 138 return op.emitOpError("expected source pointer in memory " 139 "space 0, 1, 3"); 140 141 if (NVVM::WMMALoadOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(), 142 op.eltype(), op.frag()) == 0) 143 return op.emitOpError() << "invalid attribute combination"; 144 std::pair<Type, unsigned> typeInfo = 145 inferMMAType(op.eltype(), op.frag(), op.getContext()); 146 Type dstType = LLVM::LLVMStructType::getLiteral( 147 op.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first)); 148 if (op.getType() != dstType) 149 return op.emitOpError("expected destination type is a structure of ") 150 << typeInfo.second << " elements of type " << typeInfo.first; 151 return success(); 152 } 153 154 static LogicalResult verify(NVVM::WMMAStoreOp op) { 155 unsigned addressSpace = 156 op.ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace(); 157 if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3) 158 return op.emitOpError("expected operands to be a source pointer in memory " 159 "space 0, 1, 3"); 160 161 if (NVVM::WMMAStoreOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(), 162 op.eltype()) == 0) 163 return op.emitOpError() << "invalid attribute combination"; 164 std::pair<Type, unsigned> typeInfo = 165 inferMMAType(op.eltype(), NVVM::MMAFrag::c, op.getContext()); 166 if (op.args().size() != typeInfo.second) 167 return op.emitOpError() 168 << "expected " << typeInfo.second << " data operands"; 169 if (llvm::any_of(op.args(), [&typeInfo](Value operands) { 170 return operands.getType() != typeInfo.first; 171 })) 172 return op.emitOpError() 173 << "expected data operands of type " << typeInfo.first; 174 return success(); 175 } 176 177 static LogicalResult verify(NVVM::WMMAMmaOp op) { 178 if (NVVM::WMMAMmaOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layoutA(), 179 op.layoutB(), op.eltypeA(), 180 op.eltypeB()) == 0) 181 return op.emitOpError() << "invalid attribute combination"; 182 std::pair<Type, unsigned> typeInfoA = 183 inferMMAType(op.eltypeA(), NVVM::MMAFrag::a, op.getContext()); 184 std::pair<Type, unsigned> typeInfoB = 185 inferMMAType(op.eltypeA(), NVVM::MMAFrag::b, op.getContext()); 186 std::pair<Type, unsigned> typeInfoC = 187 inferMMAType(op.eltypeB(), NVVM::MMAFrag::c, op.getContext()); 188 SmallVector<Type, 32> arguments; 189 arguments.append(typeInfoA.second, typeInfoA.first); 190 arguments.append(typeInfoB.second, typeInfoB.first); 191 arguments.append(typeInfoC.second, typeInfoC.first); 192 unsigned numArgs = arguments.size(); 193 if (op.args().size() != numArgs) 194 return op.emitOpError() << "expected " << numArgs << " arguments"; 195 for (unsigned i = 0; i < numArgs; i++) { 196 if (op.args()[i].getType() != arguments[i]) 197 return op.emitOpError() 198 << "expected argument " << i << " to be of type " << arguments[i]; 199 } 200 Type dstType = LLVM::LLVMStructType::getLiteral( 201 op.getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first)); 202 if (op.getType() != dstType) 203 return op.emitOpError("expected destination type is a structure of ") 204 << typeInfoC.second << " elements of type " << typeInfoC.first; 205 return success(); 206 } 207 208 //===----------------------------------------------------------------------===// 209 // NVVMDialect initialization, type parsing, and registration. 210 //===----------------------------------------------------------------------===// 211 212 // TODO: This should be the llvm.nvvm dialect once this is supported. 213 void NVVMDialect::initialize() { 214 addOperations< 215 #define GET_OP_LIST 216 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc" 217 >(); 218 addAttributes< 219 #define GET_ATTRDEF_LIST 220 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc" 221 >(); 222 223 // Support unknown operations because not all NVVM operations are 224 // registered. 225 allowUnknownOperations(); 226 } 227 228 LogicalResult NVVMDialect::verifyOperationAttribute(Operation *op, 229 NamedAttribute attr) { 230 // Kernel function attribute should be attached to functions. 231 if (attr.getName() == NVVMDialect::getKernelFuncAttrName()) { 232 if (!isa<LLVM::LLVMFuncOp>(op)) { 233 return op->emitError() << "'" << NVVMDialect::getKernelFuncAttrName() 234 << "' attribute attached to unexpected op"; 235 } 236 } 237 return success(); 238 } 239 240 #define GET_OP_CLASSES 241 #include "mlir/Dialect/LLVMIR/NVVMOps.cpp.inc" 242 243 #define GET_ATTRDEF_CLASSES 244 #include "mlir/Dialect/LLVMIR/NVVMOpsAttributes.cpp.inc" 245