1 //===- ROCDLDialect.cpp - ROCDL IR Ops and Dialect registration -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines the operation details (custom parsers/printers and
// attribute verification) for the ROCDL IR dialect in MLIR. It also registers
// the dialect.
11 //
12 // The ROCDL dialect only contains GPU specific additions on top of the general
13 // LLVM dialect.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
18
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/BuiltinTypes.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/Operation.h"
24 #include "llvm/AsmParser/Parser.h"
25 #include "llvm/IR/Attributes.h"
26 #include "llvm/IR/Function.h"
27 #include "llvm/IR/Type.h"
28 #include "llvm/Support/SourceMgr.h"
29
30 using namespace mlir;
31 using namespace ROCDL;
32
33 #include "mlir/Dialect/LLVMIR/ROCDLOpsDialect.cpp.inc"
34
35 //===----------------------------------------------------------------------===//
36 // Parsing for ROCDL ops
37 //===----------------------------------------------------------------------===//
38
39 // <operation> ::=
40 // `llvm.amdgcn.buffer.load.* %rsrc, %vindex, %offset, %glc, %slc :
41 // result_type`
parse(OpAsmParser & parser,OperationState & result)42 ParseResult MubufLoadOp::parse(OpAsmParser &parser, OperationState &result) {
43 SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
44 Type type;
45 if (parser.parseOperandList(ops, 5) || parser.parseColonType(type) ||
46 parser.addTypeToList(type, result.types))
47 return failure();
48
49 MLIRContext *context = parser.getContext();
50 auto int32Ty = IntegerType::get(context, 32);
51 auto int1Ty = IntegerType::get(context, 1);
52 auto i32x4Ty = LLVM::getFixedVectorType(int32Ty, 4);
53 return parser.resolveOperands(ops,
54 {i32x4Ty, int32Ty, int32Ty, int1Ty, int1Ty},
55 parser.getNameLoc(), result.operands);
56 }
57
print(OpAsmPrinter & p)58 void MubufLoadOp::print(OpAsmPrinter &p) {
59 p << " " << getOperands() << " : " << (*this)->getResultTypes();
60 }
61
62 // <operation> ::=
63 // `llvm.amdgcn.buffer.store.* %vdata, %rsrc, %vindex, %offset, %glc, %slc :
64 // result_type`
parse(OpAsmParser & parser,OperationState & result)65 ParseResult MubufStoreOp::parse(OpAsmParser &parser, OperationState &result) {
66 SmallVector<OpAsmParser::UnresolvedOperand, 8> ops;
67 Type type;
68 if (parser.parseOperandList(ops, 6) || parser.parseColonType(type))
69 return failure();
70
71 MLIRContext *context = parser.getContext();
72 auto int32Ty = IntegerType::get(context, 32);
73 auto int1Ty = IntegerType::get(context, 1);
74 auto i32x4Ty = LLVM::getFixedVectorType(int32Ty, 4);
75
76 if (parser.resolveOperands(ops,
77 {type, i32x4Ty, int32Ty, int32Ty, int1Ty, int1Ty},
78 parser.getNameLoc(), result.operands))
79 return failure();
80 return success();
81 }
82
print(OpAsmPrinter & p)83 void MubufStoreOp::print(OpAsmPrinter &p) {
84 p << " " << getOperands() << " : " << getVdata().getType();
85 }
86
87 // <operation> ::=
88 // `llvm.amdgcn.raw.buffer.load.* %rsrc, %offset, %soffset, %aux
89 // : result_type`
parse(OpAsmParser & parser,OperationState & result)90 ParseResult RawBufferLoadOp::parse(OpAsmParser &parser,
91 OperationState &result) {
92 SmallVector<OpAsmParser::UnresolvedOperand, 4> ops;
93 Type type;
94 if (parser.parseOperandList(ops, 4) || parser.parseColonType(type) ||
95 parser.addTypeToList(type, result.types))
96 return failure();
97
98 auto bldr = parser.getBuilder();
99 auto int32Ty = bldr.getI32Type();
100 auto i32x4Ty = VectorType::get({4}, int32Ty);
101 return parser.resolveOperands(ops, {i32x4Ty, int32Ty, int32Ty, int32Ty},
102 parser.getNameLoc(), result.operands);
103 }
104
print(OpAsmPrinter & p)105 void RawBufferLoadOp::print(OpAsmPrinter &p) {
106 p << " " << getOperands() << " : " << getRes().getType();
107 }
108
109 // <operation> ::=
110 // `llvm.amdgcn.raw.buffer.store.* %vdata, %rsrc, %offset,
111 // %soffset, %aux : result_type`
parse(OpAsmParser & parser,OperationState & result)112 ParseResult RawBufferStoreOp::parse(OpAsmParser &parser,
113 OperationState &result) {
114 SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
115 Type type;
116 if (parser.parseOperandList(ops, 5) || parser.parseColonType(type))
117 return failure();
118
119 auto bldr = parser.getBuilder();
120 auto int32Ty = bldr.getI32Type();
121 auto i32x4Ty = VectorType::get({4}, int32Ty);
122
123 if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty},
124 parser.getNameLoc(), result.operands))
125 return failure();
126 return success();
127 }
128
print(OpAsmPrinter & p)129 void RawBufferStoreOp::print(OpAsmPrinter &p) {
130 p << " " << getOperands() << " : " << getVdata().getType();
131 }
132
133 // <operation> ::=
134 // `llvm.amdgcn.raw.buffer.atomic.fadd.* %vdata, %rsrc, %offset,
135 // %soffset, %aux : result_type`
parse(OpAsmParser & parser,OperationState & result)136 ParseResult RawBufferAtomicFAddOp::parse(OpAsmParser &parser,
137 OperationState &result) {
138 SmallVector<OpAsmParser::UnresolvedOperand, 5> ops;
139 Type type;
140 if (parser.parseOperandList(ops, 5) || parser.parseColonType(type))
141 return failure();
142
143 auto bldr = parser.getBuilder();
144 auto int32Ty = bldr.getI32Type();
145 auto i32x4Ty = VectorType::get({4}, int32Ty);
146
147 if (parser.resolveOperands(ops, {type, i32x4Ty, int32Ty, int32Ty, int32Ty},
148 parser.getNameLoc(), result.operands))
149 return failure();
150 return success();
151 }
152
print(mlir::OpAsmPrinter & p)153 void RawBufferAtomicFAddOp::print(mlir::OpAsmPrinter &p) {
154 p << " " << getOperands() << " : " << getVdata().getType();
155 }
156
157 //===----------------------------------------------------------------------===//
// ROCDLDialect initialization, attribute verification, and registration.
159 //===----------------------------------------------------------------------===//
160
161 // TODO: This should be the llvm.rocdl dialect once this is supported.
initialize()162 void ROCDLDialect::initialize() {
163 addOperations<
164 #define GET_OP_LIST
165 #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
166 >();
167
168 // Support unknown operations because not all ROCDL operations are registered.
169 allowUnknownOperations();
170 }
171
verifyOperationAttribute(Operation * op,NamedAttribute attr)172 LogicalResult ROCDLDialect::verifyOperationAttribute(Operation *op,
173 NamedAttribute attr) {
174 // Kernel function attribute should be attached to functions.
175 if (attr.getName() == ROCDLDialect::getKernelFuncAttrName()) {
176 if (!isa<LLVM::LLVMFuncOp>(op)) {
177 return op->emitError() << "'" << ROCDLDialect::getKernelFuncAttrName()
178 << "' attribute attached to unexpected op";
179 }
180 }
181 return success();
182 }
183
184 #define GET_OP_CLASSES
185 #include "mlir/Dialect/LLVMIR/ROCDLOps.cpp.inc"
186