1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15 #include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/FunctionImplementation.h"
25 #include "mlir/IR/OpDefinition.h"
26 #include "mlir/IR/OpImplementation.h"
27 #include "mlir/IR/TypeUtilities.h"
28 #include "mlir/Interfaces/CallInterfaces.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/APInt.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/ADT/bit.h"
34
35 using namespace mlir;
36
// Attribute names shared by the custom parsers, printers and verifiers in this
// file.
// TODO: generate these strings using ODS.

// Memory access related attributes (target and source side).
static constexpr char kMemoryAccessAttrName[] = "memory_access";
static constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
static constexpr char kAlignmentAttrName[] = "alignment";
static constexpr char kSourceAlignmentAttrName[] = "source_alignment";

// Remaining attribute names.
static constexpr char kBranchWeightAttrName[] = "branch_weights";
static constexpr char kCallee[] = "callee";
static constexpr char kClusterSize[] = "cluster_size";
static constexpr char kControl[] = "control";
static constexpr char kDefaultValueAttrName[] = "default_value";
static constexpr char kExecutionScopeAttrName[] = "execution_scope";
static constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
static constexpr char kFnNameAttrName[] = "fn";
static constexpr char kGroupOperationAttrName[] = "group_operation";
static constexpr char kIndicesAttrName[] = "indices";
static constexpr char kInitializerAttrName[] = "initializer";
static constexpr char kInterfaceAttrName[] = "interface";
static constexpr char kMemoryScopeAttrName[] = "memory_scope";
static constexpr char kSemanticsAttrName[] = "semantics";
static constexpr char kSpecIdAttrName[] = "spec_id";
static constexpr char kTypeAttrName[] = "type";
static constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
static constexpr char kValueAttrName[] = "value";
static constexpr char kValuesAttrName[] = "values";
static constexpr char kCompositeSpecConstituentsName[] = "constituents";
63
64 //===----------------------------------------------------------------------===//
65 // Common utility functions
66 //===----------------------------------------------------------------------===//
67
parseOneResultSameOperandTypeOp(OpAsmParser & parser,OperationState & result)68 static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
69 OperationState &result) {
70 SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
71 Type type;
72 // If the operand list is in-between parentheses, then we have a generic form.
73 // (see the fallback in `printOneResultOp`).
74 SMLoc loc = parser.getCurrentLocation();
75 if (!parser.parseOptionalLParen()) {
76 if (parser.parseOperandList(ops) || parser.parseRParen() ||
77 parser.parseOptionalAttrDict(result.attributes) ||
78 parser.parseColon() || parser.parseType(type))
79 return failure();
80 auto fnType = type.dyn_cast<FunctionType>();
81 if (!fnType) {
82 parser.emitError(loc, "expected function type");
83 return failure();
84 }
85 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
86 return failure();
87 result.addTypes(fnType.getResults());
88 return success();
89 }
90 return failure(parser.parseOperandList(ops) ||
91 parser.parseOptionalAttrDict(result.attributes) ||
92 parser.parseColonType(type) ||
93 parser.resolveOperands(ops, type, result.operands) ||
94 parser.addTypeToList(type, result.types));
95 }
96
printOneResultOp(Operation * op,OpAsmPrinter & p)97 static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
98 assert(op->getNumResults() == 1 && "op should have one result");
99
100 // If not all the operand and result types are the same, just use the
101 // generic assembly form to avoid omitting information in printing.
102 auto resultType = op->getResult(0).getType();
103 if (llvm::any_of(op->getOperandTypes(),
104 [&](Type type) { return type != resultType; })) {
105 p.printGenericOp(op, /*printOpName=*/false);
106 return;
107 }
108
109 p << ' ';
110 p.printOperands(op->getOperands());
111 p.printOptionalAttrDict(op->getAttrs());
112 // Now we can output only one type for all operands and the result.
113 p << " : " << resultType;
114 }
115
116 /// Returns true if the given op is a function-like op or nested in a
117 /// function-like op without a module-like op in the middle.
isNestedInFunctionOpInterface(Operation * op)118 static bool isNestedInFunctionOpInterface(Operation *op) {
119 if (!op)
120 return false;
121 if (op->hasTrait<OpTrait::SymbolTable>())
122 return false;
123 if (isa<FunctionOpInterface>(op))
124 return true;
125 return isNestedInFunctionOpInterface(op->getParentOp());
126 }
127
128 /// Returns true if the given op is an module-like op that maintains a symbol
129 /// table.
isDirectInModuleLikeOp(Operation * op)130 static bool isDirectInModuleLikeOp(Operation *op) {
131 return op && op->hasTrait<OpTrait::SymbolTable>();
132 }
133
extractValueFromConstOp(Operation * op,int32_t & value)134 static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
135 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
136 if (!constOp) {
137 return failure();
138 }
139 auto valueAttr = constOp.value();
140 auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
141 if (!integerValueAttr) {
142 return failure();
143 }
144
145 if (integerValueAttr.getType().isSignlessInteger())
146 value = integerValueAttr.getInt();
147 else
148 value = integerValueAttr.getSInt();
149
150 return success();
151 }
152
153 template <typename Ty>
154 static ArrayAttr
getStrArrayAttrForEnumList(Builder & builder,ArrayRef<Ty> enumValues,function_ref<StringRef (Ty)> stringifyFn)155 getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
156 function_ref<StringRef(Ty)> stringifyFn) {
157 if (enumValues.empty()) {
158 return nullptr;
159 }
160 SmallVector<StringRef, 1> enumValStrs;
161 enumValStrs.reserve(enumValues.size());
162 for (auto val : enumValues) {
163 enumValStrs.emplace_back(stringifyFn(val));
164 }
165 return builder.getStrArrayAttr(enumValStrs);
166 }
167
168 /// Parses the next string attribute in `parser` as an enumerant of the given
169 /// `EnumClass`.
170 template <typename EnumClass>
171 static ParseResult
parseEnumStrAttr(EnumClass & value,OpAsmParser & parser,StringRef attrName=spirv::attributeName<EnumClass> ())172 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
173 StringRef attrName = spirv::attributeName<EnumClass>()) {
174 Attribute attrVal;
175 NamedAttrList attr;
176 auto loc = parser.getCurrentLocation();
177 if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
178 attrName, attr)) {
179 return failure();
180 }
181 if (!attrVal.isa<StringAttr>()) {
182 return parser.emitError(loc, "expected ")
183 << attrName << " attribute specified as string";
184 }
185 auto attrOptional =
186 spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
187 if (!attrOptional) {
188 return parser.emitError(loc, "invalid ")
189 << attrName << " attribute specification: " << attrVal;
190 }
191 value = *attrOptional;
192 return success();
193 }
194
195 /// Parses the next string attribute in `parser` as an enumerant of the given
196 /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
197 /// attribute with the enum class's name as attribute name.
198 template <typename EnumClass>
199 static ParseResult
parseEnumStrAttr(EnumClass & value,OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())200 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
201 StringRef attrName = spirv::attributeName<EnumClass>()) {
202 if (parseEnumStrAttr(value, parser)) {
203 return failure();
204 }
205 state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
206 llvm::bit_cast<int32_t>(value)));
207 return success();
208 }
209
210 /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
211 /// and inserts the enumerant into `state` as an 32-bit integer attribute with
212 /// the enum class's name as attribute name.
213 template <typename EnumClass>
214 static ParseResult
parseEnumKeywordAttr(EnumClass & value,OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())215 parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
216 OperationState &state,
217 StringRef attrName = spirv::attributeName<EnumClass>()) {
218 if (parseEnumKeywordAttr(value, parser)) {
219 return failure();
220 }
221 state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
222 llvm::bit_cast<int32_t>(value)));
223 return success();
224 }
225
226 /// Parses Function, Selection and Loop control attributes. If no control is
227 /// specified, "None" is used as a default.
228 template <typename EnumClass>
229 static ParseResult
parseControlAttribute(OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())230 parseControlAttribute(OpAsmParser &parser, OperationState &state,
231 StringRef attrName = spirv::attributeName<EnumClass>()) {
232 if (succeeded(parser.parseOptionalKeyword(kControl))) {
233 EnumClass control;
234 if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) ||
235 parser.parseRParen())
236 return failure();
237 return success();
238 }
239 // Set control to "None" otherwise.
240 Builder builder = parser.getBuilder();
241 state.addAttribute(attrName, builder.getI32IntegerAttr(0));
242 return success();
243 }
244
245 /// Parses optional memory access attributes attached to a memory access
246 /// operand/pointer. Specifically, parses the following syntax:
247 /// (`[` memory-access `]`)?
248 /// where:
249 /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
250 /// integer-literal | `"NonTemporal"`
parseMemoryAccessAttributes(OpAsmParser & parser,OperationState & state)251 static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
252 OperationState &state) {
253 // Parse an optional list of attributes staring with '['
254 if (parser.parseOptionalLSquare()) {
255 // Nothing to do
256 return success();
257 }
258
259 spirv::MemoryAccess memoryAccessAttr;
260 if (parseEnumStrAttr(memoryAccessAttr, parser, state,
261 kMemoryAccessAttrName)) {
262 return failure();
263 }
264
265 if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
266 // Parse integer attribute for alignment.
267 Attribute alignmentAttr;
268 Type i32Type = parser.getBuilder().getIntegerType(32);
269 if (parser.parseComma() ||
270 parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
271 state.attributes)) {
272 return failure();
273 }
274 }
275 return parser.parseRSquare();
276 }
277
278 // TODO Make sure to merge this and the previous function into one template
279 // parameterized by memory access attribute name and alignment. Doing so now
280 // results in VS2017 in producing an internal error (at the call site) that's
281 // not detailed enough to understand what is happening.
parseSourceMemoryAccessAttributes(OpAsmParser & parser,OperationState & state)282 static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
283 OperationState &state) {
284 // Parse an optional list of attributes staring with '['
285 if (parser.parseOptionalLSquare()) {
286 // Nothing to do
287 return success();
288 }
289
290 spirv::MemoryAccess memoryAccessAttr;
291 if (parseEnumStrAttr(memoryAccessAttr, parser, state,
292 kSourceMemoryAccessAttrName)) {
293 return failure();
294 }
295
296 if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
297 // Parse integer attribute for alignment.
298 Attribute alignmentAttr;
299 Type i32Type = parser.getBuilder().getIntegerType(32);
300 if (parser.parseComma() ||
301 parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
302 state.attributes)) {
303 return failure();
304 }
305 }
306 return parser.parseRSquare();
307 }
308
309 template <typename MemoryOpTy>
printMemoryAccessAttribute(MemoryOpTy memoryOp,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs,Optional<spirv::MemoryAccess> memoryAccessAtrrValue=None,Optional<uint32_t> alignmentAttrValue=None)310 static void printMemoryAccessAttribute(
311 MemoryOpTy memoryOp, OpAsmPrinter &printer,
312 SmallVectorImpl<StringRef> &elidedAttrs,
313 Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
314 Optional<uint32_t> alignmentAttrValue = None) {
315 // Print optional memory access attribute.
316 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
317 : memoryOp.memory_access())) {
318 elidedAttrs.push_back(kMemoryAccessAttrName);
319
320 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
321
322 if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
323 // Print integer alignment attribute.
324 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
325 : memoryOp.alignment())) {
326 elidedAttrs.push_back(kAlignmentAttrName);
327 printer << ", " << alignment;
328 }
329 }
330 printer << "]";
331 }
332 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
333 }
334
335 // TODO Make sure to merge this and the previous function into one template
336 // parameterized by memory access attribute name and alignment. Doing so now
337 // results in VS2017 in producing an internal error (at the call site) that's
338 // not detailed enough to understand what is happening.
339 template <typename MemoryOpTy>
printSourceMemoryAccessAttribute(MemoryOpTy memoryOp,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs,Optional<spirv::MemoryAccess> memoryAccessAtrrValue=None,Optional<uint32_t> alignmentAttrValue=None)340 static void printSourceMemoryAccessAttribute(
341 MemoryOpTy memoryOp, OpAsmPrinter &printer,
342 SmallVectorImpl<StringRef> &elidedAttrs,
343 Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
344 Optional<uint32_t> alignmentAttrValue = None) {
345
346 printer << ", ";
347
348 // Print optional memory access attribute.
349 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
350 : memoryOp.memory_access())) {
351 elidedAttrs.push_back(kSourceMemoryAccessAttrName);
352
353 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
354
355 if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
356 // Print integer alignment attribute.
357 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
358 : memoryOp.alignment())) {
359 elidedAttrs.push_back(kSourceAlignmentAttrName);
360 printer << ", " << alignment;
361 }
362 }
363 printer << "]";
364 }
365 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
366 }
367
parseImageOperands(OpAsmParser & parser,spirv::ImageOperandsAttr & attr)368 static ParseResult parseImageOperands(OpAsmParser &parser,
369 spirv::ImageOperandsAttr &attr) {
370 // Expect image operands
371 if (parser.parseOptionalLSquare())
372 return success();
373
374 spirv::ImageOperands imageOperands;
375 if (parseEnumStrAttr(imageOperands, parser))
376 return failure();
377
378 attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);
379
380 return parser.parseRSquare();
381 }
382
printImageOperands(OpAsmPrinter & printer,Operation * imageOp,spirv::ImageOperandsAttr attr)383 static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
384 spirv::ImageOperandsAttr attr) {
385 if (attr) {
386 auto strImageOperands = stringifyImageOperands(attr.getValue());
387 printer << "[\"" << strImageOperands << "\"]";
388 }
389 }
390
391 template <typename Op>
verifyImageOperands(Op imageOp,spirv::ImageOperandsAttr attr,Operation::operand_range operands)392 static LogicalResult verifyImageOperands(Op imageOp,
393 spirv::ImageOperandsAttr attr,
394 Operation::operand_range operands) {
395 if (!attr) {
396 if (operands.empty())
397 return success();
398
399 return imageOp.emitError("the Image Operands should encode what operands "
400 "follow, as per Image Operands");
401 }
402
403 // TODO: Add the validation rules for the following Image Operands.
404 spirv::ImageOperands noSupportOperands =
405 spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
406 spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
407 spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
408 spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
409 spirv::ImageOperands::MakeTexelAvailable |
410 spirv::ImageOperands::MakeTexelVisible |
411 spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
412
413 if (spirv::bitEnumContains(attr.getValue(), noSupportOperands))
414 llvm_unreachable("unimplemented operands of Image Operands");
415
416 return success();
417 }
418
verifyCastOp(Operation * op,bool requireSameBitWidth=true,bool skipBitWidthCheck=false)419 static LogicalResult verifyCastOp(Operation *op,
420 bool requireSameBitWidth = true,
421 bool skipBitWidthCheck = false) {
422 // Some CastOps have no limit on bit widths for result and operand type.
423 if (skipBitWidthCheck)
424 return success();
425
426 Type operandType = op->getOperand(0).getType();
427 Type resultType = op->getResult(0).getType();
428
429 // ODS checks that result type and operand type have the same shape.
430 if (auto vectorType = operandType.dyn_cast<VectorType>()) {
431 operandType = vectorType.getElementType();
432 resultType = resultType.cast<VectorType>().getElementType();
433 }
434
435 if (auto coopMatrixType =
436 operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
437 operandType = coopMatrixType.getElementType();
438 resultType =
439 resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
440 }
441
442 auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
443 auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
444 auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
445
446 if (requireSameBitWidth) {
447 if (!isSameBitWidth) {
448 return op->emitOpError(
449 "expected the same bit widths for operand type and result "
450 "type, but provided ")
451 << operandType << " and " << resultType;
452 }
453 return success();
454 }
455
456 if (isSameBitWidth) {
457 return op->emitOpError(
458 "expected the different bit widths for operand type and result "
459 "type, but provided ")
460 << operandType << " and " << resultType;
461 }
462 return success();
463 }
464
465 template <typename MemoryOpTy>
verifyMemoryAccessAttribute(MemoryOpTy memoryOp)466 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
467 // ODS checks for attributes values. Just need to verify that if the
468 // memory-access attribute is Aligned, then the alignment attribute must be
469 // present.
470 auto *op = memoryOp.getOperation();
471 auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
472 if (!memAccessAttr) {
473 // Alignment attribute shouldn't be present if memory access attribute is
474 // not present.
475 if (op->getAttr(kAlignmentAttrName)) {
476 return memoryOp.emitOpError(
477 "invalid alignment specification without aligned memory access "
478 "specification");
479 }
480 return success();
481 }
482
483 auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
484 auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
485
486 if (!memAccess) {
487 return memoryOp.emitOpError("invalid memory access specifier: ")
488 << memAccessVal;
489 }
490
491 if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
492 if (!op->getAttr(kAlignmentAttrName)) {
493 return memoryOp.emitOpError("missing alignment value");
494 }
495 } else {
496 if (op->getAttr(kAlignmentAttrName)) {
497 return memoryOp.emitOpError(
498 "invalid alignment specification with non-aligned memory access "
499 "specification");
500 }
501 }
502 return success();
503 }
504
505 // TODO Make sure to merge this and the previous function into one template
506 // parameterized by memory access attribute name and alignment. Doing so now
507 // results in VS2017 in producing an internal error (at the call site) that's
508 // not detailed enough to understand what is happening.
509 template <typename MemoryOpTy>
verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)510 static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
511 // ODS checks for attributes values. Just need to verify that if the
512 // memory-access attribute is Aligned, then the alignment attribute must be
513 // present.
514 auto *op = memoryOp.getOperation();
515 auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
516 if (!memAccessAttr) {
517 // Alignment attribute shouldn't be present if memory access attribute is
518 // not present.
519 if (op->getAttr(kSourceAlignmentAttrName)) {
520 return memoryOp.emitOpError(
521 "invalid alignment specification without aligned memory access "
522 "specification");
523 }
524 return success();
525 }
526
527 auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
528 auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
529
530 if (!memAccess) {
531 return memoryOp.emitOpError("invalid memory access specifier: ")
532 << memAccessVal;
533 }
534
535 if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
536 if (!op->getAttr(kSourceAlignmentAttrName)) {
537 return memoryOp.emitOpError("missing alignment value");
538 }
539 } else {
540 if (op->getAttr(kSourceAlignmentAttrName)) {
541 return memoryOp.emitOpError(
542 "invalid alignment specification with non-aligned memory access "
543 "specification");
544 }
545 }
546 return success();
547 }
548
549 static LogicalResult
verifyMemorySemantics(Operation * op,spirv::MemorySemantics memorySemantics)550 verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) {
551 // According to the SPIR-V specification:
552 // "Despite being a mask and allowing multiple bits to be combined, it is
553 // invalid for more than one of these four bits to be set: Acquire, Release,
554 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
555 // Release semantics is done by setting the AcquireRelease bit, not by setting
556 // two bits."
557 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
558 spirv::MemorySemantics::Release |
559 spirv::MemorySemantics::AcquireRelease |
560 spirv::MemorySemantics::SequentiallyConsistent;
561
562 auto bitCount = llvm::countPopulation(
563 static_cast<uint32_t>(memorySemantics & atMostOneInSet));
564 if (bitCount > 1) {
565 return op->emitError(
566 "expected at most one of these four memory constraints "
567 "to be set: `Acquire`, `Release`,"
568 "`AcquireRelease` or `SequentiallyConsistent`");
569 }
570 return success();
571 }
572
573 template <typename LoadStoreOpTy>
verifyLoadStorePtrAndValTypes(LoadStoreOpTy op,Value ptr,Value val)574 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
575 Value val) {
576 // ODS already checks ptr is spirv::PointerType. Just check that the pointee
577 // type of the pointer and the type of the value are the same
578 //
579 // TODO: Check that the value type satisfies restrictions of
580 // SPIR-V OpLoad/OpStore operations
581 if (val.getType() !=
582 ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
583 return op.emitOpError("mismatch in result type and pointer type");
584 }
585 return success();
586 }
587
588 template <typename BlockReadWriteOpTy>
verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,Value ptr,Value val)589 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
590 Value ptr, Value val) {
591 auto valType = val.getType();
592 if (auto valVecTy = valType.dyn_cast<VectorType>())
593 valType = valVecTy.getElementType();
594
595 if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
596 return op.emitOpError("mismatch in result type and pointer type");
597 }
598 return success();
599 }
600
parseVariableDecorations(OpAsmParser & parser,OperationState & state)601 static ParseResult parseVariableDecorations(OpAsmParser &parser,
602 OperationState &state) {
603 auto builtInName = llvm::convertToSnakeFromCamelCase(
604 stringifyDecoration(spirv::Decoration::BuiltIn));
605 if (succeeded(parser.parseOptionalKeyword("bind"))) {
606 Attribute set, binding;
607 // Parse optional descriptor binding
608 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
609 stringifyDecoration(spirv::Decoration::DescriptorSet));
610 auto bindingName = llvm::convertToSnakeFromCamelCase(
611 stringifyDecoration(spirv::Decoration::Binding));
612 Type i32Type = parser.getBuilder().getIntegerType(32);
613 if (parser.parseLParen() ||
614 parser.parseAttribute(set, i32Type, descriptorSetName,
615 state.attributes) ||
616 parser.parseComma() ||
617 parser.parseAttribute(binding, i32Type, bindingName,
618 state.attributes) ||
619 parser.parseRParen()) {
620 return failure();
621 }
622 } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
623 StringAttr builtIn;
624 if (parser.parseLParen() ||
625 parser.parseAttribute(builtIn, builtInName, state.attributes) ||
626 parser.parseRParen()) {
627 return failure();
628 }
629 }
630
631 // Parse other attributes
632 if (parser.parseOptionalAttrDict(state.attributes))
633 return failure();
634
635 return success();
636 }
637
printVariableDecorations(Operation * op,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs)638 static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
639 SmallVectorImpl<StringRef> &elidedAttrs) {
640 // Print optional descriptor binding
641 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
642 stringifyDecoration(spirv::Decoration::DescriptorSet));
643 auto bindingName = llvm::convertToSnakeFromCamelCase(
644 stringifyDecoration(spirv::Decoration::Binding));
645 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
646 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
647 if (descriptorSet && binding) {
648 elidedAttrs.push_back(descriptorSetName);
649 elidedAttrs.push_back(bindingName);
650 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
651 << ")";
652 }
653
654 // Print BuiltIn attribute if present
655 auto builtInName = llvm::convertToSnakeFromCamelCase(
656 stringifyDecoration(spirv::Decoration::BuiltIn));
657 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
658 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
659 elidedAttrs.push_back(builtInName);
660 }
661
662 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
663 }
664
665 // Get bit width of types.
getBitWidth(Type type)666 static unsigned getBitWidth(Type type) {
667 if (type.isa<spirv::PointerType>()) {
668 // Just return 64 bits for pointer types for now.
669 // TODO: Make sure not caller relies on the actual pointer width value.
670 return 64;
671 }
672
673 if (type.isIntOrFloat())
674 return type.getIntOrFloatBitWidth();
675
676 if (auto vectorType = type.dyn_cast<VectorType>()) {
677 assert(vectorType.getElementType().isIntOrFloat());
678 return vectorType.getNumElements() *
679 vectorType.getElementType().getIntOrFloatBitWidth();
680 }
681 llvm_unreachable("unhandled bit width computation for type");
682 }
683
684 /// Walks the given type hierarchy with the given indices, potentially down
685 /// to component granularity, to select an element type. Returns null type and
686 /// emits errors with the given loc on failure.
687 static Type
getElementType(Type type,ArrayRef<int32_t> indices,function_ref<InFlightDiagnostic (StringRef)> emitErrorFn)688 getElementType(Type type, ArrayRef<int32_t> indices,
689 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
690 if (indices.empty()) {
691 emitErrorFn("expected at least one index for spv.CompositeExtract");
692 return nullptr;
693 }
694
695 for (auto index : indices) {
696 if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
697 if (cType.hasCompileTimeKnownNumElements() &&
698 (index < 0 ||
699 static_cast<uint64_t>(index) >= cType.getNumElements())) {
700 emitErrorFn("index ") << index << " out of bounds for " << type;
701 return nullptr;
702 }
703 type = cType.getElementType(index);
704 } else {
705 emitErrorFn("cannot extract from non-composite type ")
706 << type << " with index " << index;
707 return nullptr;
708 }
709 }
710 return type;
711 }
712
713 static Type
getElementType(Type type,Attribute indices,function_ref<InFlightDiagnostic (StringRef)> emitErrorFn)714 getElementType(Type type, Attribute indices,
715 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
716 auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
717 if (!indicesArrayAttr) {
718 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
719 return nullptr;
720 }
721 if (indicesArrayAttr.empty()) {
722 emitErrorFn("expected at least one index for spv.CompositeExtract");
723 return nullptr;
724 }
725
726 SmallVector<int32_t, 2> indexVals;
727 for (auto indexAttr : indicesArrayAttr) {
728 auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
729 if (!indexIntAttr) {
730 emitErrorFn("expected an 32-bit integer for index, but found '")
731 << indexAttr << "'";
732 return nullptr;
733 }
734 indexVals.push_back(indexIntAttr.getInt());
735 }
736 return getElementType(type, indexVals, emitErrorFn);
737 }
738
getElementType(Type type,Attribute indices,Location loc)739 static Type getElementType(Type type, Attribute indices, Location loc) {
740 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
741 return ::mlir::emitError(loc, err);
742 };
743 return getElementType(type, indices, errorFn);
744 }
745
getElementType(Type type,Attribute indices,OpAsmParser & parser,SMLoc loc)746 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
747 SMLoc loc) {
748 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
749 return parser.emitError(loc, err);
750 };
751 return getElementType(type, indices, errorFn);
752 }
753
754 /// Returns true if the given `block` only contains one `spv.mlir.merge` op.
isMergeBlock(Block & block)755 static inline bool isMergeBlock(Block &block) {
756 return !block.empty() && std::next(block.begin()) == block.end() &&
757 isa<spirv::MergeOp>(block.front());
758 }
759
760 //===----------------------------------------------------------------------===//
761 // Common parsers and printers
762 //===----------------------------------------------------------------------===//
763
764 // Parses an atomic update op. If the update op does not take a value (like
765 // AtomicIIncrement) `hasValue` must be false.
parseAtomicUpdateOp(OpAsmParser & parser,OperationState & state,bool hasValue)766 static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
767 OperationState &state, bool hasValue) {
768 spirv::Scope scope;
769 spirv::MemorySemantics memoryScope;
770 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
771 OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
772 Type type;
773 SMLoc loc;
774 if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
775 parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) ||
776 parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
777 parser.getCurrentLocation(&loc) || parser.parseColonType(type))
778 return failure();
779
780 auto ptrType = type.dyn_cast<spirv::PointerType>();
781 if (!ptrType)
782 return parser.emitError(loc, "expected pointer type");
783
784 SmallVector<Type, 2> operandTypes;
785 operandTypes.push_back(ptrType);
786 if (hasValue)
787 operandTypes.push_back(ptrType.getPointeeType());
788 if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
789 state.operands))
790 return failure();
791 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
792 }
793
794 // Prints an atomic update op.
printAtomicUpdateOp(Operation * op,OpAsmPrinter & printer)795 static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
796 printer << " \"";
797 auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
798 printer << spirv::stringifyScope(
799 static_cast<spirv::Scope>(scopeAttr.getInt()))
800 << "\" \"";
801 auto memorySemanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
802 printer << spirv::stringifyMemorySemantics(
803 static_cast<spirv::MemorySemantics>(
804 memorySemanticsAttr.getInt()))
805 << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
806 }
807
808 template <typename T>
809 static StringRef stringifyTypeName();
810
811 template <>
stringifyTypeName()812 StringRef stringifyTypeName<IntegerType>() {
813 return "integer";
814 }
815
816 template <>
stringifyTypeName()817 StringRef stringifyTypeName<FloatType>() {
818 return "float";
819 }
820
821 // Verifies an atomic update op.
822 template <typename ExpectedElementType>
verifyAtomicUpdateOp(Operation * op)823 static LogicalResult verifyAtomicUpdateOp(Operation *op) {
824 auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
825 auto elementType = ptrType.getPointeeType();
826 if (!elementType.isa<ExpectedElementType>())
827 return op->emitOpError() << "pointer operand must point to an "
828 << stringifyTypeName<ExpectedElementType>()
829 << " value, found " << elementType;
830
831 if (op->getNumOperands() > 1) {
832 auto valueType = op->getOperand(1).getType();
833 if (valueType != elementType)
834 return op->emitOpError("expected value to have the same type as the "
835 "pointer operand's pointee type ")
836 << elementType << ", but found " << valueType;
837 }
838 auto memorySemantics = static_cast<spirv::MemorySemantics>(
839 op->getAttrOfType<IntegerAttr>(kSemanticsAttrName).getInt());
840 if (failed(verifyMemorySemantics(op, memorySemantics))) {
841 return failure();
842 }
843 return success();
844 }
845
parseGroupNonUniformArithmeticOp(OpAsmParser & parser,OperationState & state)846 static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
847 OperationState &state) {
848 spirv::Scope executionScope;
849 spirv::GroupOperation groupOperation;
850 OpAsmParser::UnresolvedOperand valueInfo;
851 if (parseEnumStrAttr(executionScope, parser, state,
852 kExecutionScopeAttrName) ||
853 parseEnumStrAttr(groupOperation, parser, state,
854 kGroupOperationAttrName) ||
855 parser.parseOperand(valueInfo))
856 return failure();
857
858 Optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
859 if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
860 clusterSizeInfo = OpAsmParser::UnresolvedOperand();
861 if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
862 parser.parseRParen())
863 return failure();
864 }
865
866 Type resultType;
867 if (parser.parseColonType(resultType))
868 return failure();
869
870 if (parser.resolveOperand(valueInfo, resultType, state.operands))
871 return failure();
872
873 if (clusterSizeInfo) {
874 Type i32Type = parser.getBuilder().getIntegerType(32);
875 if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
876 return failure();
877 }
878
879 return parser.addTypeToList(resultType, state.types);
880 }
881
printGroupNonUniformArithmeticOp(Operation * groupOp,OpAsmPrinter & printer)882 static void printGroupNonUniformArithmeticOp(Operation *groupOp,
883 OpAsmPrinter &printer) {
884 printer << " \""
885 << stringifyScope(static_cast<spirv::Scope>(
886 groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName)
887 .getInt()))
888 << "\" \""
889 << stringifyGroupOperation(static_cast<spirv::GroupOperation>(
890 groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName)
891 .getInt()))
892 << "\" " << groupOp->getOperand(0);
893
894 if (groupOp->getNumOperands() > 1)
895 printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
896 printer << " : " << groupOp->getResult(0).getType();
897 }
898
verifyGroupNonUniformArithmeticOp(Operation * groupOp)899 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
900 spirv::Scope scope = static_cast<spirv::Scope>(
901 groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
902 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
903 return groupOp->emitOpError(
904 "execution scope must be 'Workgroup' or 'Subgroup'");
905
906 spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
907 groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
908 if (operation == spirv::GroupOperation::ClusteredReduce &&
909 groupOp->getNumOperands() == 1)
910 return groupOp->emitOpError("cluster size operand must be provided for "
911 "'ClusteredReduce' group operation");
912 if (groupOp->getNumOperands() > 1) {
913 Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
914 int32_t clusterSize = 0;
915
916 // TODO: support specialization constant here.
917 if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
918 return groupOp->emitOpError(
919 "cluster size operand must come from a constant op");
920
921 if (!llvm::isPowerOf2_32(clusterSize))
922 return groupOp->emitOpError(
923 "cluster size operand must be a power of two");
924 }
925 return success();
926 }
927
928 /// Result of a logical op must be a scalar or vector of boolean type.
getUnaryOpResultType(Type operandType)929 static Type getUnaryOpResultType(Type operandType) {
930 Builder builder(operandType.getContext());
931 Type resultType = builder.getIntegerType(1);
932 if (auto vecType = operandType.dyn_cast<VectorType>())
933 return VectorType::get(vecType.getNumElements(), resultType);
934 return resultType;
935 }
936
verifyShiftOp(Operation * op)937 static LogicalResult verifyShiftOp(Operation *op) {
938 if (op->getOperand(0).getType() != op->getResult(0).getType()) {
939 return op->emitError("expected the same type for the first operand and "
940 "result, but provided ")
941 << op->getOperand(0).getType() << " and "
942 << op->getResult(0).getType();
943 }
944 return success();
945 }
946
buildLogicalBinaryOp(OpBuilder & builder,OperationState & state,Value lhs,Value rhs)947 static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
948 Value lhs, Value rhs) {
949 assert(lhs.getType() == rhs.getType());
950
951 Type boolType = builder.getI1Type();
952 if (auto vecType = lhs.getType().dyn_cast<VectorType>())
953 boolType = VectorType::get(vecType.getShape(), boolType);
954 state.addTypes(boolType);
955
956 state.addOperands({lhs, rhs});
957 }
958
buildLogicalUnaryOp(OpBuilder & builder,OperationState & state,Value value)959 static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
960 Value value) {
961 Type boolType = builder.getI1Type();
962 if (auto vecType = value.getType().dyn_cast<VectorType>())
963 boolType = VectorType::get(vecType.getShape(), boolType);
964 state.addTypes(boolType);
965
966 state.addOperands(value);
967 }
968
969 //===----------------------------------------------------------------------===//
970 // spv.AccessChainOp
971 //===----------------------------------------------------------------------===//
972
/// Computes the pointer type produced by applying `indices` to the pointer
/// `type`, as spv.AccessChain does. The returned pointer keeps the storage
/// class of `type` and points to the element type reached after consuming
/// every index.
///
/// Every level that an index is applied to must be a composite type.
///  - For struct types, the index must have a defining op from which
///    `extractValueFromConstOp` can read an integer, and that integer must be
///    in bounds.
///  - For other composite types, the index value is not inspected here.
///    Element 0 is used to query the element type.
///
/// On any error a diagnostic is emitted at `baseLoc` and a null Type is
/// returned.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
  auto ptrType = type.dyn_cast<spirv::PointerType>();
  if (!ptrType) {
    emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
                       "to composite type, but provided ")
        << type;
    return nullptr;
  }

  auto resultType = ptrType.getPointeeType();
  auto resultStorageClass = ptrType.getStorageClass();
  // Element index used to step into the current composite. It is reset to 0
  // for every index and only overwritten when stepping into a struct.
  int32_t index = 0;

  for (auto indexSSA : indices) {
    auto cType = resultType.dyn_cast<spirv::CompositeType>();
    if (!cType) {
      // NOTE(review): `index` still holds the value used at the previous
      // level (or 0), not the offending index operand.
      emitError(baseLoc,
                "'spv.AccessChain' op cannot extract from non-composite type ")
          << resultType << " with index " << index;
      return nullptr;
    }
    index = 0;
    if (resultType.isa<spirv::StructType>()) {
      // Per the diagnostics below, struct indices must come from an integer
      // spv.Constant. A block argument has no defining op and is rejected.
      Operation *op = indexSSA.getDefiningOp();
      if (!op) {
        emitError(baseLoc, "'spv.AccessChain' op index must be an "
                           "integer spv.Constant to access "
                           "element of spv.struct");
        return nullptr;
      }

      // TODO: this should be relaxed to allow
      // integer literals of other bitwidths.
      if (failed(extractValueFromConstOp(op, index))) {
        emitError(baseLoc,
                  "'spv.AccessChain' index must be an integer spv.Constant to "
                  "access element of spv.struct, but provided ")
            << op->getName();
        return nullptr;
      }
      // Reject negative or past-the-end member indices.
      if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
        emitError(baseLoc, "'spv.AccessChain' op index ")
            << index << " out of bounds for " << resultType;
        return nullptr;
      }
    }
    // Descend one level into the composite.
    resultType = cType.getElementType(index);
  }
  return spirv::PointerType::get(resultType, resultStorageClass);
}
1023
build(OpBuilder & builder,OperationState & state,Value basePtr,ValueRange indices)1024 void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
1025 Value basePtr, ValueRange indices) {
1026 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
1027 assert(type && "Unable to deduce return type based on basePtr and indices");
1028 build(builder, state, type, basePtr, indices);
1029 }
1030
/// Parses an access chain of the form
///   `%base[%i0, %i1, ...] : <base-pointer-type>, <i0-type>, <i1-type>, ...`
/// The base pointer operand is resolved first, followed by the index operands.
/// The result type is inferred from the base pointer type and the resolved
/// index values.
ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser,
                                        OperationState &state) {
  OpAsmParser::UnresolvedOperand ptrInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
  Type type;
  // Location just before the base pointer operand. It is used when resolving
  // the index operands below.
  auto loc = parser.getCurrentLocation();
  SmallVector<Type, 4> indicesTypes;

  if (parser.parseOperand(ptrInfo) ||
      parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(ptrInfo, type, state.operands)) {
    return failure();
  }

  // Check that the provided indices list is not empty before parsing their
  // type list.
  if (indicesInfo.empty()) {
    return mlir::emitError(state.location, "'spv.AccessChain' op expected at "
                                           "least one index ");
  }

  if (parser.parseComma() || parser.parseTypeList(indicesTypes))
    return failure();

  // Check that the indices types list is not empty and that it has a one-to-one
  // mapping to the provided indices.
  if (indicesTypes.size() != indicesInfo.size()) {
    return mlir::emitError(state.location,
                           "'spv.AccessChain' op indices types' count must be "
                           "equal to indices info count");
  }

  if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
    return failure();

  // state.operands[0] is the base pointer resolved above. Drop it so that only
  // the index values are used to infer the result type.
  auto resultType = getElementPtrType(
      type, llvm::makeArrayRef(state.operands).drop_front(), state.location);
  if (!resultType) {
    return failure();
  }

  state.addTypes(resultType);
  return success();
}
1076
1077 template <typename Op>
printAccessChain(Op op,ValueRange indices,OpAsmPrinter & printer)1078 static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
1079 printer << ' ' << op.base_ptr() << '[' << indices
1080 << "] : " << op.base_ptr().getType() << ", " << indices.getTypes();
1081 }
1082
print(OpAsmPrinter & printer)1083 void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
1084 printAccessChain(*this, indices(), printer);
1085 }
1086
1087 template <typename Op>
verifyAccessChain(Op accessChainOp,ValueRange indices)1088 static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
1089 auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
1090 indices, accessChainOp.getLoc());
1091 if (!resultType)
1092 return failure();
1093
1094 auto providedResultType =
1095 accessChainOp.getType().template dyn_cast<spirv::PointerType>();
1096 if (!providedResultType)
1097 return accessChainOp.emitOpError(
1098 "result type must be a pointer, but provided")
1099 << providedResultType;
1100
1101 if (resultType != providedResultType)
1102 return accessChainOp.emitOpError("invalid result type: expected ")
1103 << resultType << ", but provided " << providedResultType;
1104
1105 return success();
1106 }
1107
verify()1108 LogicalResult spirv::AccessChainOp::verify() {
1109 return verifyAccessChain(*this, indices());
1110 }
1111
1112 //===----------------------------------------------------------------------===//
1113 // spv.mlir.addressof
1114 //===----------------------------------------------------------------------===//
1115
build(OpBuilder & builder,OperationState & state,spirv::GlobalVariableOp var)1116 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
1117 spirv::GlobalVariableOp var) {
1118 build(builder, state, var.type(), SymbolRefAttr::get(var));
1119 }
1120
verify()1121 LogicalResult spirv::AddressOfOp::verify() {
1122 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1123 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
1124 variableAttr()));
1125 if (!varOp) {
1126 return emitOpError("expected spv.GlobalVariable symbol");
1127 }
1128 if (pointer().getType() != varOp.type()) {
1129 return emitOpError(
1130 "result type mismatch with the referenced global variable's type");
1131 }
1132 return success();
1133 }
1134
1135 template <typename T>
printAtomicCompareExchangeImpl(T atomOp,OpAsmPrinter & printer)1136 static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
1137 printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
1138 << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
1139 << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
1140 << atomOp.getOperands() << " : " << atomOp.pointer().getType();
1141 }
1142
parseAtomicCompareExchangeImpl(OpAsmParser & parser,OperationState & state)1143 static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
1144 OperationState &state) {
1145 spirv::Scope memoryScope;
1146 spirv::MemorySemantics equalSemantics, unequalSemantics;
1147 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
1148 Type type;
1149 if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
1150 parseEnumStrAttr(equalSemantics, parser, state,
1151 kEqualSemanticsAttrName) ||
1152 parseEnumStrAttr(unequalSemantics, parser, state,
1153 kUnequalSemanticsAttrName) ||
1154 parser.parseOperandList(operandInfo, 3))
1155 return failure();
1156
1157 auto loc = parser.getCurrentLocation();
1158 if (parser.parseColonType(type))
1159 return failure();
1160
1161 auto ptrType = type.dyn_cast<spirv::PointerType>();
1162 if (!ptrType)
1163 return parser.emitError(loc, "expected pointer type");
1164
1165 if (parser.resolveOperands(
1166 operandInfo,
1167 {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1168 parser.getNameLoc(), state.operands))
1169 return failure();
1170
1171 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1172 }
1173
1174 template <typename T>
verifyAtomicCompareExchangeImpl(T atomOp)1175 static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
1176 // According to the spec:
1177 // "The type of Value must be the same as Result Type. The type of the value
1178 // pointed to by Pointer must be the same as Result Type. This type must also
1179 // match the type of Comparator."
1180 if (atomOp.getType() != atomOp.value().getType())
1181 return atomOp.emitOpError("value operand must have the same type as the op "
1182 "result, but found ")
1183 << atomOp.value().getType() << " vs " << atomOp.getType();
1184
1185 if (atomOp.getType() != atomOp.comparator().getType())
1186 return atomOp.emitOpError(
1187 "comparator operand must have the same type as the op "
1188 "result, but found ")
1189 << atomOp.comparator().getType() << " vs " << atomOp.getType();
1190
1191 Type pointeeType = atomOp.pointer()
1192 .getType()
1193 .template cast<spirv::PointerType>()
1194 .getPointeeType();
1195 if (atomOp.getType() != pointeeType)
1196 return atomOp.emitOpError(
1197 "pointer operand's pointee type must have the same "
1198 "as the op result type, but found ")
1199 << pointeeType << " vs " << atomOp.getType();
1200
1201 // TODO: Unequal cannot be set to Release or Acquire and Release.
1202 // In addition, Unequal cannot be set to a stronger memory-order then Equal.
1203
1204 return success();
1205 }
1206
1207 //===----------------------------------------------------------------------===//
1208 // spv.AtomicAndOp
1209 //===----------------------------------------------------------------------===//
1210
verify()1211 LogicalResult spirv::AtomicAndOp::verify() {
1212 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1213 }
1214
parse(OpAsmParser & parser,OperationState & result)1215 ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser,
1216 OperationState &result) {
1217 return ::parseAtomicUpdateOp(parser, result, true);
1218 }
print(OpAsmPrinter & p)1219 void spirv::AtomicAndOp::print(OpAsmPrinter &p) {
1220 ::printAtomicUpdateOp(*this, p);
1221 }
1222
1223 //===----------------------------------------------------------------------===//
1224 // spv.AtomicCompareExchangeOp
1225 //===----------------------------------------------------------------------===//
1226
verify()1227 LogicalResult spirv::AtomicCompareExchangeOp::verify() {
1228 return ::verifyAtomicCompareExchangeImpl(*this);
1229 }
1230
parse(OpAsmParser & parser,OperationState & result)1231 ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser,
1232 OperationState &result) {
1233 return ::parseAtomicCompareExchangeImpl(parser, result);
1234 }
print(OpAsmPrinter & p)1235 void spirv::AtomicCompareExchangeOp::print(OpAsmPrinter &p) {
1236 ::printAtomicCompareExchangeImpl(*this, p);
1237 }
1238
1239 //===----------------------------------------------------------------------===//
1240 // spv.AtomicCompareExchangeWeakOp
1241 //===----------------------------------------------------------------------===//
1242
verify()1243 LogicalResult spirv::AtomicCompareExchangeWeakOp::verify() {
1244 return ::verifyAtomicCompareExchangeImpl(*this);
1245 }
1246
parse(OpAsmParser & parser,OperationState & result)1247 ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
1248 OperationState &result) {
1249 return ::parseAtomicCompareExchangeImpl(parser, result);
1250 }
print(OpAsmPrinter & p)1251 void spirv::AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) {
1252 ::printAtomicCompareExchangeImpl(*this, p);
1253 }
1254
1255 //===----------------------------------------------------------------------===//
1256 // spv.AtomicExchange
1257 //===----------------------------------------------------------------------===//
1258
print(OpAsmPrinter & printer)1259 void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) {
1260 printer << " \"" << stringifyScope(memory_scope()) << "\" \""
1261 << stringifyMemorySemantics(semantics()) << "\" " << getOperands()
1262 << " : " << pointer().getType();
1263 }
1264
parse(OpAsmParser & parser,OperationState & state)1265 ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
1266 OperationState &state) {
1267 spirv::Scope memoryScope;
1268 spirv::MemorySemantics semantics;
1269 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
1270 Type type;
1271 if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
1272 parseEnumStrAttr(semantics, parser, state, kSemanticsAttrName) ||
1273 parser.parseOperandList(operandInfo, 2))
1274 return failure();
1275
1276 auto loc = parser.getCurrentLocation();
1277 if (parser.parseColonType(type))
1278 return failure();
1279
1280 auto ptrType = type.dyn_cast<spirv::PointerType>();
1281 if (!ptrType)
1282 return parser.emitError(loc, "expected pointer type");
1283
1284 if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
1285 parser.getNameLoc(), state.operands))
1286 return failure();
1287
1288 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1289 }
1290
verify()1291 LogicalResult spirv::AtomicExchangeOp::verify() {
1292 if (getType() != value().getType())
1293 return emitOpError("value operand must have the same type as the op "
1294 "result, but found ")
1295 << value().getType() << " vs " << getType();
1296
1297 Type pointeeType =
1298 pointer().getType().cast<spirv::PointerType>().getPointeeType();
1299 if (getType() != pointeeType)
1300 return emitOpError("pointer operand's pointee type must have the same "
1301 "as the op result type, but found ")
1302 << pointeeType << " vs " << getType();
1303
1304 return success();
1305 }
1306
1307 //===----------------------------------------------------------------------===//
1308 // spv.AtomicIAddOp
1309 //===----------------------------------------------------------------------===//
1310
verify()1311 LogicalResult spirv::AtomicIAddOp::verify() {
1312 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1313 }
1314
parse(OpAsmParser & parser,OperationState & result)1315 ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser,
1316 OperationState &result) {
1317 return ::parseAtomicUpdateOp(parser, result, true);
1318 }
print(OpAsmPrinter & p)1319 void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
1320 ::printAtomicUpdateOp(*this, p);
1321 }
1322
1323 //===----------------------------------------------------------------------===//
1324 // spv.AtomicFAddEXTOp
1325 //===----------------------------------------------------------------------===//
1326
verify()1327 LogicalResult spirv::AtomicFAddEXTOp::verify() {
1328 return ::verifyAtomicUpdateOp<FloatType>(getOperation());
1329 }
1330
parse(OpAsmParser & parser,OperationState & result)1331 ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser,
1332 OperationState &result) {
1333 return ::parseAtomicUpdateOp(parser, result, true);
1334 }
print(OpAsmPrinter & p)1335 void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) {
1336 ::printAtomicUpdateOp(*this, p);
1337 }
1338
1339 //===----------------------------------------------------------------------===//
1340 // spv.AtomicIDecrementOp
1341 //===----------------------------------------------------------------------===//
1342
verify()1343 LogicalResult spirv::AtomicIDecrementOp::verify() {
1344 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1345 }
1346
parse(OpAsmParser & parser,OperationState & result)1347 ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser,
1348 OperationState &result) {
1349 return ::parseAtomicUpdateOp(parser, result, false);
1350 }
print(OpAsmPrinter & p)1351 void spirv::AtomicIDecrementOp::print(OpAsmPrinter &p) {
1352 ::printAtomicUpdateOp(*this, p);
1353 }
1354
1355 //===----------------------------------------------------------------------===//
1356 // spv.AtomicIIncrementOp
1357 //===----------------------------------------------------------------------===//
1358
verify()1359 LogicalResult spirv::AtomicIIncrementOp::verify() {
1360 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1361 }
1362
parse(OpAsmParser & parser,OperationState & result)1363 ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser,
1364 OperationState &result) {
1365 return ::parseAtomicUpdateOp(parser, result, false);
1366 }
print(OpAsmPrinter & p)1367 void spirv::AtomicIIncrementOp::print(OpAsmPrinter &p) {
1368 ::printAtomicUpdateOp(*this, p);
1369 }
1370
1371 //===----------------------------------------------------------------------===//
1372 // spv.AtomicISubOp
1373 //===----------------------------------------------------------------------===//
1374
verify()1375 LogicalResult spirv::AtomicISubOp::verify() {
1376 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1377 }
1378
parse(OpAsmParser & parser,OperationState & result)1379 ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser,
1380 OperationState &result) {
1381 return ::parseAtomicUpdateOp(parser, result, true);
1382 }
print(OpAsmPrinter & p)1383 void spirv::AtomicISubOp::print(OpAsmPrinter &p) {
1384 ::printAtomicUpdateOp(*this, p);
1385 }
1386
1387 //===----------------------------------------------------------------------===//
1388 // spv.AtomicOrOp
1389 //===----------------------------------------------------------------------===//
1390
verify()1391 LogicalResult spirv::AtomicOrOp::verify() {
1392 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1393 }
1394
parse(OpAsmParser & parser,OperationState & result)1395 ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser,
1396 OperationState &result) {
1397 return ::parseAtomicUpdateOp(parser, result, true);
1398 }
print(OpAsmPrinter & p)1399 void spirv::AtomicOrOp::print(OpAsmPrinter &p) {
1400 ::printAtomicUpdateOp(*this, p);
1401 }
1402
1403 //===----------------------------------------------------------------------===//
1404 // spv.AtomicSMaxOp
1405 //===----------------------------------------------------------------------===//
1406
verify()1407 LogicalResult spirv::AtomicSMaxOp::verify() {
1408 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1409 }
1410
parse(OpAsmParser & parser,OperationState & result)1411 ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser,
1412 OperationState &result) {
1413 return ::parseAtomicUpdateOp(parser, result, true);
1414 }
print(OpAsmPrinter & p)1415 void spirv::AtomicSMaxOp::print(OpAsmPrinter &p) {
1416 ::printAtomicUpdateOp(*this, p);
1417 }
1418
1419 //===----------------------------------------------------------------------===//
1420 // spv.AtomicSMinOp
1421 //===----------------------------------------------------------------------===//
1422
verify()1423 LogicalResult spirv::AtomicSMinOp::verify() {
1424 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1425 }
1426
parse(OpAsmParser & parser,OperationState & result)1427 ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser,
1428 OperationState &result) {
1429 return ::parseAtomicUpdateOp(parser, result, true);
1430 }
print(OpAsmPrinter & p)1431 void spirv::AtomicSMinOp::print(OpAsmPrinter &p) {
1432 ::printAtomicUpdateOp(*this, p);
1433 }
1434
1435 //===----------------------------------------------------------------------===//
1436 // spv.AtomicUMaxOp
1437 //===----------------------------------------------------------------------===//
1438
verify()1439 LogicalResult spirv::AtomicUMaxOp::verify() {
1440 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1441 }
1442
parse(OpAsmParser & parser,OperationState & result)1443 ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser,
1444 OperationState &result) {
1445 return ::parseAtomicUpdateOp(parser, result, true);
1446 }
print(OpAsmPrinter & p)1447 void spirv::AtomicUMaxOp::print(OpAsmPrinter &p) {
1448 ::printAtomicUpdateOp(*this, p);
1449 }
1450
1451 //===----------------------------------------------------------------------===//
1452 // spv.AtomicUMinOp
1453 //===----------------------------------------------------------------------===//
1454
verify()1455 LogicalResult spirv::AtomicUMinOp::verify() {
1456 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1457 }
1458
parse(OpAsmParser & parser,OperationState & result)1459 ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser,
1460 OperationState &result) {
1461 return ::parseAtomicUpdateOp(parser, result, true);
1462 }
print(OpAsmPrinter & p)1463 void spirv::AtomicUMinOp::print(OpAsmPrinter &p) {
1464 ::printAtomicUpdateOp(*this, p);
1465 }
1466
1467 //===----------------------------------------------------------------------===//
1468 // spv.AtomicXorOp
1469 //===----------------------------------------------------------------------===//
1470
verify()1471 LogicalResult spirv::AtomicXorOp::verify() {
1472 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1473 }
1474
parse(OpAsmParser & parser,OperationState & result)1475 ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser,
1476 OperationState &result) {
1477 return ::parseAtomicUpdateOp(parser, result, true);
1478 }
print(OpAsmPrinter & p)1479 void spirv::AtomicXorOp::print(OpAsmPrinter &p) {
1480 ::printAtomicUpdateOp(*this, p);
1481 }
1482
1483 //===----------------------------------------------------------------------===//
1484 // spv.BitcastOp
1485 //===----------------------------------------------------------------------===//
1486
verify()1487 LogicalResult spirv::BitcastOp::verify() {
1488 // TODO: The SPIR-V spec validation rules are different for different
1489 // versions.
1490 auto operandType = operand().getType();
1491 auto resultType = result().getType();
1492 if (operandType == resultType) {
1493 return emitError("result type must be different from operand type");
1494 }
1495 if (operandType.isa<spirv::PointerType>() &&
1496 !resultType.isa<spirv::PointerType>()) {
1497 return emitError(
1498 "unhandled bit cast conversion from pointer type to non-pointer type");
1499 }
1500 if (!operandType.isa<spirv::PointerType>() &&
1501 resultType.isa<spirv::PointerType>()) {
1502 return emitError(
1503 "unhandled bit cast conversion from non-pointer type to pointer type");
1504 }
1505 auto operandBitWidth = getBitWidth(operandType);
1506 auto resultBitWidth = getBitWidth(resultType);
1507 if (operandBitWidth != resultBitWidth) {
1508 return emitOpError("mismatch in result type bitwidth ")
1509 << resultBitWidth << " and operand type bitwidth "
1510 << operandBitWidth;
1511 }
1512 return success();
1513 }
1514
1515 //===----------------------------------------------------------------------===//
1516 // spv.BranchOp
1517 //===----------------------------------------------------------------------===//
1518
getSuccessorOperands(unsigned index)1519 SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
1520 assert(index == 0 && "invalid successor index");
1521 return SuccessorOperands(0, targetOperandsMutable());
1522 }
1523
1524 //===----------------------------------------------------------------------===//
1525 // spv.BranchConditionalOp
1526 //===----------------------------------------------------------------------===//
1527
1528 SuccessorOperands
getSuccessorOperands(unsigned index)1529 spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
1530 assert(index < 2 && "invalid successor index");
1531 return SuccessorOperands(index == kTrueIndex ? trueTargetOperandsMutable()
1532 : falseTargetOperandsMutable());
1533 }
1534
/// Parses a conditional branch of the form
///   `%cond [ <true-weight>, <false-weight> ]?, <true-successor>, <false-successor>`
/// where each successor is parsed by `parseSuccessorAndUseList` together with
/// its operand list. The condition is resolved as i1. Optional branch weights
/// are parsed as i32 and stored under `branch_weights`. The operand segment
/// sizes record {condition, true operands, false operands}.
ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
                                              OperationState &state) {
  auto &builder = parser.getBuilder();
  OpAsmParser::UnresolvedOperand condInfo;
  Block *dest;

  // Parse the condition.
  Type boolTy = builder.getI1Type();
  if (parser.parseOperand(condInfo) ||
      parser.resolveOperand(condInfo, boolTy, state.operands))
    return failure();

  // Parse the optional branch weights.
  if (succeeded(parser.parseOptionalLSquare())) {
    IntegerAttr trueWeight, falseWeight;
    // Scratch list required by parseAttribute. Only the parsed IntegerAttrs
    // are kept, packed into one array attribute below.
    NamedAttrList weights;

    auto i32Type = builder.getIntegerType(32);
    if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
        parser.parseComma() ||
        parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
        parser.parseRSquare())
      return failure();

    state.addAttribute(kBranchWeightAttrName,
                       builder.getArrayAttr({trueWeight, falseWeight}));
  }

  // Parse the true branch.
  SmallVector<Value, 4> trueOperands;
  if (parser.parseComma() ||
      parser.parseSuccessorAndUseList(dest, trueOperands))
    return failure();
  state.addSuccessors(dest);
  state.addOperands(trueOperands);

  // Parse the false branch.
  SmallVector<Value, 4> falseOperands;
  if (parser.parseComma() ||
      parser.parseSuccessorAndUseList(dest, falseOperands))
    return failure();
  state.addSuccessors(dest);
  state.addOperands(falseOperands);
  // Segment sizes: 1 condition, then the true and false operand counts.
  state.addAttribute(
      spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
      builder.getI32VectorAttr({1, static_cast<int32_t>(trueOperands.size()),
                                static_cast<int32_t>(falseOperands.size())}));

  return success();
}
1585
print(OpAsmPrinter & printer)1586 void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) {
1587 printer << ' ' << condition();
1588
1589 if (auto weights = branch_weights()) {
1590 printer << " [";
1591 llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
1592 printer << a.cast<IntegerAttr>().getInt();
1593 });
1594 printer << "]";
1595 }
1596
1597 printer << ", ";
1598 printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
1599 printer << ", ";
1600 printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
1601 }
1602
verify()1603 LogicalResult spirv::BranchConditionalOp::verify() {
1604 if (auto weights = branch_weights()) {
1605 if (weights->getValue().size() != 2) {
1606 return emitOpError("must have exactly two branch weights");
1607 }
1608 if (llvm::all_of(*weights, [](Attribute attr) {
1609 return attr.cast<IntegerAttr>().getValue().isNullValue();
1610 }))
1611 return emitOpError("branch weights cannot both be zero");
1612 }
1613
1614 return success();
1615 }
1616
1617 //===----------------------------------------------------------------------===//
1618 // spv.CompositeConstruct
1619 //===----------------------------------------------------------------------===//
1620
parse(OpAsmParser & parser,OperationState & state)1621 ParseResult spirv::CompositeConstructOp::parse(OpAsmParser &parser,
1622 OperationState &state) {
1623 SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1624 Type type;
1625 auto loc = parser.getCurrentLocation();
1626
1627 if (parser.parseOperandList(operands) || parser.parseColonType(type)) {
1628 return failure();
1629 }
1630 auto cType = type.dyn_cast<spirv::CompositeType>();
1631 if (!cType) {
1632 return parser.emitError(
1633 loc, "result type must be a composite type, but provided ")
1634 << type;
1635 }
1636
1637 if (cType.hasCompileTimeKnownNumElements() &&
1638 operands.size() != cType.getNumElements()) {
1639 return parser.emitError(loc, "has incorrect number of operands: expected ")
1640 << cType.getNumElements() << ", but provided " << operands.size();
1641 }
1642 // TODO: Add support for constructing a vector type from the vector operands.
1643 // According to the spec: "for constructing a vector, the operands may
1644 // also be vectors with the same component type as the Result Type component
1645 // type".
1646 SmallVector<Type, 4> elementTypes;
1647 elementTypes.reserve(operands.size());
1648 for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
1649 elementTypes.push_back(cType.getElementType(index));
1650 }
1651 state.addTypes(type);
1652 return parser.resolveOperands(operands, elementTypes, loc, state.operands);
1653 }
1654
print(OpAsmPrinter & printer)1655 void spirv::CompositeConstructOp::print(OpAsmPrinter &printer) {
1656 printer << " " << constituents() << " : " << getResult().getType();
1657 }
1658
verify()1659 LogicalResult spirv::CompositeConstructOp::verify() {
1660 auto cType = getType().cast<spirv::CompositeType>();
1661 operand_range constituents = this->constituents();
1662
1663 if (cType.isa<spirv::CooperativeMatrixNVType>()) {
1664 if (constituents.size() != 1)
1665 return emitError("has incorrect number of operands: expected ")
1666 << "1, but provided " << constituents.size();
1667 } else if (constituents.size() != cType.getNumElements()) {
1668 return emitError("has incorrect number of operands: expected ")
1669 << cType.getNumElements() << ", but provided "
1670 << constituents.size();
1671 }
1672
1673 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1674 if (constituents[index].getType() != cType.getElementType(index)) {
1675 return emitError("operand type mismatch: expected operand type ")
1676 << cType.getElementType(index) << ", but provided "
1677 << constituents[index].getType();
1678 }
1679 }
1680
1681 return success();
1682 }
1683
1684 //===----------------------------------------------------------------------===//
1685 // spv.CompositeExtractOp
1686 //===----------------------------------------------------------------------===//
1687
build(OpBuilder & builder,OperationState & state,Value composite,ArrayRef<int32_t> indices)1688 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
1689 Value composite,
1690 ArrayRef<int32_t> indices) {
1691 auto indexAttr = builder.getI32ArrayAttr(indices);
1692 auto elementType =
1693 getElementType(composite.getType(), indexAttr, state.location);
1694 if (!elementType) {
1695 return;
1696 }
1697 build(builder, state, elementType, composite, indexAttr);
1698 }
1699
parse(OpAsmParser & parser,OperationState & state)1700 ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
1701 OperationState &state) {
1702 OpAsmParser::UnresolvedOperand compositeInfo;
1703 Attribute indicesAttr;
1704 Type compositeType;
1705 SMLoc attrLocation;
1706
1707 if (parser.parseOperand(compositeInfo) ||
1708 parser.getCurrentLocation(&attrLocation) ||
1709 parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1710 parser.parseColonType(compositeType) ||
1711 parser.resolveOperand(compositeInfo, compositeType, state.operands)) {
1712 return failure();
1713 }
1714
1715 Type resultType =
1716 getElementType(compositeType, indicesAttr, parser, attrLocation);
1717 if (!resultType) {
1718 return failure();
1719 }
1720 state.addTypes(resultType);
1721 return success();
1722 }
1723
print(OpAsmPrinter & printer)1724 void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
1725 printer << ' ' << composite() << indices() << " : " << composite().getType();
1726 }
1727
verify()1728 LogicalResult spirv::CompositeExtractOp::verify() {
1729 auto indicesArrayAttr = indices().dyn_cast<ArrayAttr>();
1730 auto resultType =
1731 getElementType(composite().getType(), indicesArrayAttr, getLoc());
1732 if (!resultType)
1733 return failure();
1734
1735 if (resultType != getType()) {
1736 return emitOpError("invalid result type: expected ")
1737 << resultType << " but provided " << getType();
1738 }
1739
1740 return success();
1741 }
1742
1743 //===----------------------------------------------------------------------===//
1744 // spv.CompositeInsert
1745 //===----------------------------------------------------------------------===//
1746
build(OpBuilder & builder,OperationState & state,Value object,Value composite,ArrayRef<int32_t> indices)1747 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
1748 Value object, Value composite,
1749 ArrayRef<int32_t> indices) {
1750 auto indexAttr = builder.getI32ArrayAttr(indices);
1751 build(builder, state, composite.getType(), object, composite, indexAttr);
1752 }
1753
parse(OpAsmParser & parser,OperationState & state)1754 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
1755 OperationState &state) {
1756 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
1757 Type objectType, compositeType;
1758 Attribute indicesAttr;
1759 auto loc = parser.getCurrentLocation();
1760
1761 return failure(
1762 parser.parseOperandList(operands, 2) ||
1763 parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1764 parser.parseColonType(objectType) ||
1765 parser.parseKeywordType("into", compositeType) ||
1766 parser.resolveOperands(operands, {objectType, compositeType}, loc,
1767 state.operands) ||
1768 parser.addTypesToList(compositeType, state.types));
1769 }
1770
verify()1771 LogicalResult spirv::CompositeInsertOp::verify() {
1772 auto indicesArrayAttr = indices().dyn_cast<ArrayAttr>();
1773 auto objectType =
1774 getElementType(composite().getType(), indicesArrayAttr, getLoc());
1775 if (!objectType)
1776 return failure();
1777
1778 if (objectType != object().getType()) {
1779 return emitOpError("object operand type should be ")
1780 << objectType << ", but found " << object().getType();
1781 }
1782
1783 if (composite().getType() != getType()) {
1784 return emitOpError("result type should be the same as "
1785 "the composite type, but found ")
1786 << composite().getType() << " vs " << getType();
1787 }
1788
1789 return success();
1790 }
1791
print(OpAsmPrinter & printer)1792 void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
1793 printer << " " << object() << ", " << composite() << indices() << " : "
1794 << object().getType() << " into " << composite().getType();
1795 }
1796
1797 //===----------------------------------------------------------------------===//
1798 // spv.Constant
1799 //===----------------------------------------------------------------------===//
1800
parse(OpAsmParser & parser,OperationState & state)1801 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
1802 OperationState &state) {
1803 Attribute value;
1804 if (parser.parseAttribute(value, kValueAttrName, state.attributes))
1805 return failure();
1806
1807 Type type = value.getType();
1808 if (type.isa<NoneType, TensorType>()) {
1809 if (parser.parseColonType(type))
1810 return failure();
1811 }
1812
1813 return parser.addTypeToList(type, state.types);
1814 }
1815
print(OpAsmPrinter & printer)1816 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
1817 printer << ' ' << value();
1818 if (getType().isa<spirv::ArrayType>())
1819 printer << " : " << getType();
1820 }
1821
verifyConstantType(spirv::ConstantOp op,Attribute value,Type opType)1822 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
1823 Type opType) {
1824 auto valueType = value.getType();
1825
1826 if (value.isa<IntegerAttr, FloatAttr>()) {
1827 if (valueType != opType)
1828 return op.emitOpError("result type (")
1829 << opType << ") does not match value type (" << valueType << ")";
1830 return success();
1831 }
1832 if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
1833 if (valueType == opType)
1834 return success();
1835 auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1836 auto shapedType = valueType.dyn_cast<ShapedType>();
1837 if (!arrayType)
1838 return op.emitOpError("result or element type (")
1839 << opType << ") does not match value type (" << valueType
1840 << "), must be the same or spv.array";
1841
1842 int numElements = arrayType.getNumElements();
1843 auto opElemType = arrayType.getElementType();
1844 while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
1845 numElements *= t.getNumElements();
1846 opElemType = t.getElementType();
1847 }
1848 if (!opElemType.isIntOrFloat())
1849 return op.emitOpError("only support nested array result type");
1850
1851 auto valueElemType = shapedType.getElementType();
1852 if (valueElemType != opElemType) {
1853 return op.emitOpError("result element type (")
1854 << opElemType << ") does not match value element type ("
1855 << valueElemType << ")";
1856 }
1857
1858 if (numElements != shapedType.getNumElements()) {
1859 return op.emitOpError("result number of elements (")
1860 << numElements << ") does not match value number of elements ("
1861 << shapedType.getNumElements() << ")";
1862 }
1863 return success();
1864 }
1865 if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) {
1866 auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1867 if (!arrayType)
1868 return op.emitOpError("must have spv.array result type for array value");
1869 Type elemType = arrayType.getElementType();
1870 for (Attribute element : arrayAttr.getValue()) {
1871 // Verify array elements recursively.
1872 if (failed(verifyConstantType(op, element, elemType)))
1873 return failure();
1874 }
1875 return success();
1876 }
1877 return op.emitOpError("cannot have value of type ") << valueType;
1878 }
1879
verify()1880 LogicalResult spirv::ConstantOp::verify() {
1881 // ODS already generates checks to make sure the result type is valid. We just
1882 // need to additionally check that the value's attribute type is consistent
1883 // with the result type.
1884 return verifyConstantType(*this, valueAttr(), getType());
1885 }
1886
isBuildableWith(Type type)1887 bool spirv::ConstantOp::isBuildableWith(Type type) {
1888 // Must be valid SPIR-V type first.
1889 if (!type.isa<spirv::SPIRVType>())
1890 return false;
1891
1892 if (isa<SPIRVDialect>(type.getDialect())) {
1893 // TODO: support constant struct
1894 return type.isa<spirv::ArrayType>();
1895 }
1896
1897 return true;
1898 }
1899
getZero(Type type,Location loc,OpBuilder & builder)1900 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
1901 OpBuilder &builder) {
1902 if (auto intType = type.dyn_cast<IntegerType>()) {
1903 unsigned width = intType.getWidth();
1904 if (width == 1)
1905 return builder.create<spirv::ConstantOp>(loc, type,
1906 builder.getBoolAttr(false));
1907 return builder.create<spirv::ConstantOp>(
1908 loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
1909 }
1910 if (auto floatType = type.dyn_cast<FloatType>()) {
1911 return builder.create<spirv::ConstantOp>(
1912 loc, type, builder.getFloatAttr(floatType, 0.0));
1913 }
1914 if (auto vectorType = type.dyn_cast<VectorType>()) {
1915 Type elemType = vectorType.getElementType();
1916 if (elemType.isa<IntegerType>()) {
1917 return builder.create<spirv::ConstantOp>(
1918 loc, type,
1919 DenseElementsAttr::get(vectorType,
1920 IntegerAttr::get(elemType, 0.0).getValue()));
1921 }
1922 if (elemType.isa<FloatType>()) {
1923 return builder.create<spirv::ConstantOp>(
1924 loc, type,
1925 DenseFPElementsAttr::get(vectorType,
1926 FloatAttr::get(elemType, 0.0).getValue()));
1927 }
1928 }
1929
1930 llvm_unreachable("unimplemented types for ConstantOp::getZero()");
1931 }
1932
getOne(Type type,Location loc,OpBuilder & builder)1933 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
1934 OpBuilder &builder) {
1935 if (auto intType = type.dyn_cast<IntegerType>()) {
1936 unsigned width = intType.getWidth();
1937 if (width == 1)
1938 return builder.create<spirv::ConstantOp>(loc, type,
1939 builder.getBoolAttr(true));
1940 return builder.create<spirv::ConstantOp>(
1941 loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
1942 }
1943 if (auto floatType = type.dyn_cast<FloatType>()) {
1944 return builder.create<spirv::ConstantOp>(
1945 loc, type, builder.getFloatAttr(floatType, 1.0));
1946 }
1947 if (auto vectorType = type.dyn_cast<VectorType>()) {
1948 Type elemType = vectorType.getElementType();
1949 if (elemType.isa<IntegerType>()) {
1950 return builder.create<spirv::ConstantOp>(
1951 loc, type,
1952 DenseElementsAttr::get(vectorType,
1953 IntegerAttr::get(elemType, 1.0).getValue()));
1954 }
1955 if (elemType.isa<FloatType>()) {
1956 return builder.create<spirv::ConstantOp>(
1957 loc, type,
1958 DenseFPElementsAttr::get(vectorType,
1959 FloatAttr::get(elemType, 1.0).getValue()));
1960 }
1961 }
1962
1963 llvm_unreachable("unimplemented types for ConstantOp::getOne()");
1964 }
1965
getAsmResultNames(llvm::function_ref<void (mlir::Value,llvm::StringRef)> setNameFn)1966 void mlir::spirv::ConstantOp::getAsmResultNames(
1967 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
1968 Type type = getType();
1969
1970 SmallString<32> specialNameBuffer;
1971 llvm::raw_svector_ostream specialName(specialNameBuffer);
1972 specialName << "cst";
1973
1974 IntegerType intTy = type.dyn_cast<IntegerType>();
1975
1976 if (IntegerAttr intCst = value().dyn_cast<IntegerAttr>()) {
1977 if (intTy && intTy.getWidth() == 1) {
1978 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1979 }
1980
1981 if (intTy.isSignless()) {
1982 specialName << intCst.getInt();
1983 } else {
1984 specialName << intCst.getSInt();
1985 }
1986 }
1987
1988 if (intTy || type.isa<FloatType>()) {
1989 specialName << '_' << type;
1990 }
1991
1992 if (auto vecType = type.dyn_cast<VectorType>()) {
1993 specialName << "_vec_";
1994 specialName << vecType.getDimSize(0);
1995
1996 Type elementType = vecType.getElementType();
1997
1998 if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
1999 specialName << "x" << elementType;
2000 }
2001 }
2002
2003 setNameFn(getResult(), specialName.str());
2004 }
2005
getAsmResultNames(llvm::function_ref<void (mlir::Value,llvm::StringRef)> setNameFn)2006 void mlir::spirv::AddressOfOp::getAsmResultNames(
2007 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
2008 SmallString<32> specialNameBuffer;
2009 llvm::raw_svector_ostream specialName(specialNameBuffer);
2010 specialName << variable() << "_addr";
2011 setNameFn(getResult(), specialName.str());
2012 }
2013
2014 //===----------------------------------------------------------------------===//
2015 // spv.ControlBarrierOp
2016 //===----------------------------------------------------------------------===//
2017
verify()2018 LogicalResult spirv::ControlBarrierOp::verify() {
2019 return verifyMemorySemantics(getOperation(), memory_semantics());
2020 }
2021
2022 //===----------------------------------------------------------------------===//
2023 // spv.ConvertFToSOp
2024 //===----------------------------------------------------------------------===//
2025
verify()2026 LogicalResult spirv::ConvertFToSOp::verify() {
2027 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2028 /*skipBitWidthCheck=*/true);
2029 }
2030
2031 //===----------------------------------------------------------------------===//
2032 // spv.ConvertFToUOp
2033 //===----------------------------------------------------------------------===//
2034
verify()2035 LogicalResult spirv::ConvertFToUOp::verify() {
2036 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2037 /*skipBitWidthCheck=*/true);
2038 }
2039
2040 //===----------------------------------------------------------------------===//
2041 // spv.ConvertSToFOp
2042 //===----------------------------------------------------------------------===//
2043
verify()2044 LogicalResult spirv::ConvertSToFOp::verify() {
2045 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2046 /*skipBitWidthCheck=*/true);
2047 }
2048
2049 //===----------------------------------------------------------------------===//
2050 // spv.ConvertUToFOp
2051 //===----------------------------------------------------------------------===//
2052
verify()2053 LogicalResult spirv::ConvertUToFOp::verify() {
2054 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2055 /*skipBitWidthCheck=*/true);
2056 }
2057
2058 //===----------------------------------------------------------------------===//
2059 // spv.EntryPoint
2060 //===----------------------------------------------------------------------===//
2061
build(OpBuilder & builder,OperationState & state,spirv::ExecutionModel executionModel,spirv::FuncOp function,ArrayRef<Attribute> interfaceVars)2062 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
2063 spirv::ExecutionModel executionModel,
2064 spirv::FuncOp function,
2065 ArrayRef<Attribute> interfaceVars) {
2066 build(builder, state,
2067 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
2068 SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
2069 }
2070
parse(OpAsmParser & parser,OperationState & state)2071 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
2072 OperationState &state) {
2073 spirv::ExecutionModel execModel;
2074 SmallVector<OpAsmParser::UnresolvedOperand, 0> identifiers;
2075 SmallVector<Type, 0> idTypes;
2076 SmallVector<Attribute, 4> interfaceVars;
2077
2078 FlatSymbolRefAttr fn;
2079 if (parseEnumStrAttr(execModel, parser, state) ||
2080 parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
2081 return failure();
2082 }
2083
2084 if (!parser.parseOptionalComma()) {
2085 // Parse the interface variables
2086 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
2087 // The name of the interface variable attribute isnt important
2088 FlatSymbolRefAttr var;
2089 NamedAttrList attrs;
2090 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
2091 return failure();
2092 interfaceVars.push_back(var);
2093 return success();
2094 }))
2095 return failure();
2096 }
2097 state.addAttribute(kInterfaceAttrName,
2098 parser.getBuilder().getArrayAttr(interfaceVars));
2099 return success();
2100 }
2101
print(OpAsmPrinter & printer)2102 void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
2103 printer << " \"" << stringifyExecutionModel(execution_model()) << "\" ";
2104 printer.printSymbolName(fn());
2105 auto interfaceVars = interface().getValue();
2106 if (!interfaceVars.empty()) {
2107 printer << ", ";
2108 llvm::interleaveComma(interfaceVars, printer);
2109 }
2110 }
2111
verify()2112 LogicalResult spirv::EntryPointOp::verify() {
2113 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
2114 // verification.
2115 return success();
2116 }
2117
2118 //===----------------------------------------------------------------------===//
2119 // spv.ExecutionMode
2120 //===----------------------------------------------------------------------===//
2121
build(OpBuilder & builder,OperationState & state,spirv::FuncOp function,spirv::ExecutionMode executionMode,ArrayRef<int32_t> params)2122 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
2123 spirv::FuncOp function,
2124 spirv::ExecutionMode executionMode,
2125 ArrayRef<int32_t> params) {
2126 build(builder, state, SymbolRefAttr::get(function),
2127 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
2128 builder.getI32ArrayAttr(params));
2129 }
2130
parse(OpAsmParser & parser,OperationState & state)2131 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
2132 OperationState &state) {
2133 spirv::ExecutionMode execMode;
2134 Attribute fn;
2135 if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) ||
2136 parseEnumStrAttr(execMode, parser, state)) {
2137 return failure();
2138 }
2139
2140 SmallVector<int32_t, 4> values;
2141 Type i32Type = parser.getBuilder().getIntegerType(32);
2142 while (!parser.parseOptionalComma()) {
2143 NamedAttrList attr;
2144 Attribute value;
2145 if (parser.parseAttribute(value, i32Type, "value", attr)) {
2146 return failure();
2147 }
2148 values.push_back(value.cast<IntegerAttr>().getInt());
2149 }
2150 state.addAttribute(kValuesAttrName,
2151 parser.getBuilder().getI32ArrayAttr(values));
2152 return success();
2153 }
2154
print(OpAsmPrinter & printer)2155 void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
2156 printer << " ";
2157 printer.printSymbolName(fn());
2158 printer << " \"" << stringifyExecutionMode(execution_mode()) << "\"";
2159 auto values = this->values();
2160 if (values.empty())
2161 return;
2162 printer << ", ";
2163 llvm::interleaveComma(values, printer, [&](Attribute a) {
2164 printer << a.cast<IntegerAttr>().getInt();
2165 });
2166 }
2167
2168 //===----------------------------------------------------------------------===//
2169 // spv.FConvertOp
2170 //===----------------------------------------------------------------------===//
2171
verify()2172 LogicalResult spirv::FConvertOp::verify() {
2173 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2174 }
2175
2176 //===----------------------------------------------------------------------===//
2177 // spv.SConvertOp
2178 //===----------------------------------------------------------------------===//
2179
verify()2180 LogicalResult spirv::SConvertOp::verify() {
2181 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2182 }
2183
2184 //===----------------------------------------------------------------------===//
2185 // spv.UConvertOp
2186 //===----------------------------------------------------------------------===//
2187
verify()2188 LogicalResult spirv::UConvertOp::verify() {
2189 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2190 }
2191
2192 //===----------------------------------------------------------------------===//
2193 // spv.func
2194 //===----------------------------------------------------------------------===//
2195
/// Parses a spv.func of the form
///   @name(<args>) -> <results> "<FunctionControl>" [attributes {...}] [{body}]
/// The body region is optional; a function parsed without one has an empty
/// region, which is treated as external.
ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
  SmallVector<OpAsmParser::Argument> entryArgs;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  auto &builder = parser.getBuilder();

  // Parse the name as a symbol.
  StringAttr nameAttr;
  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
                             state.attributes))
    return failure();

  // Parse the function signature. Variadic signatures are not allowed
  // (allowVariadic=false), so isVariadic only serves as an out-parameter.
  bool isVariadic = false;
  if (function_interface_impl::parseFunctionSignature(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs))
    return failure();

  // Build the FunctionType from the parsed argument/result types and record
  // it as the function type attribute.
  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto fnType = builder.getFunctionType(argTypes, resultTypes);
  state.addAttribute(FunctionOpInterface::getTypeAttrName(),
                     TypeAttr::get(fnType));

  // Parse the function control keyword (printed as a quoted string).
  // NOTE(review): a failure from parseEnumStrAttr aborts the parse, so the
  // keyword is effectively required here unless the helper itself accepts a
  // missing value -- confirm against parseEnumStrAttr.
  spirv::FunctionControl fnControl;
  if (parseEnumStrAttr(fnControl, parser, state))
    return failure();

  // If additional attributes are present, parse them.
  if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
    return failure();

  // Attach the parsed per-argument and per-result attributes to the op.
  assert(resultAttrs.size() == resultTypes.size());
  function_interface_impl::addArgAndResultAttrs(builder, state, entryArgs,
                                                resultAttrs);

  // Parse the optional function body. The region is always added; it stays
  // empty when no body is present, and only a present-but-malformed body
  // fails the parse.
  auto *body = state.addRegion();
  OptionalParseResult result = parser.parseOptionalRegion(*body, entryArgs);
  return failure(result.hasValue() && failed(*result));
}
2241
print(OpAsmPrinter & printer)2242 void spirv::FuncOp::print(OpAsmPrinter &printer) {
2243 // Print function name, signature, and control.
2244 printer << " ";
2245 printer.printSymbolName(sym_name());
2246 auto fnType = getFunctionType();
2247 function_interface_impl::printFunctionSignature(
2248 printer, *this, fnType.getInputs(),
2249 /*isVariadic=*/false, fnType.getResults());
2250 printer << " \"" << spirv::stringifyFunctionControl(function_control())
2251 << "\"";
2252 function_interface_impl::printFunctionAttributes(
2253 printer, *this, fnType.getNumInputs(), fnType.getNumResults(),
2254 {spirv::attributeName<spirv::FunctionControl>()});
2255
2256 // Print the body if this is not an external function.
2257 Region &body = this->body();
2258 if (!body.empty()) {
2259 printer << ' ';
2260 printer.printRegion(body, /*printEntryBlockArgs=*/false,
2261 /*printBlockTerminators=*/true);
2262 }
2263 }
2264
verifyType()2265 LogicalResult spirv::FuncOp::verifyType() {
2266 auto type = getFunctionTypeAttr().getValue();
2267 if (!type.isa<FunctionType>())
2268 return emitOpError("requires '" + getTypeAttrName() +
2269 "' attribute of function type");
2270 if (getFunctionType().getNumResults() > 1)
2271 return emitOpError("cannot have more than one result");
2272 return success();
2273 }
2274
verifyBody()2275 LogicalResult spirv::FuncOp::verifyBody() {
2276 FunctionType fnType = getFunctionType();
2277
2278 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
2279 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
2280 if (fnType.getNumResults() != 0)
2281 return retOp.emitOpError("cannot be used in functions returning value");
2282 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
2283 if (fnType.getNumResults() != 1)
2284 return retOp.emitOpError(
2285 "returns 1 value but enclosing function requires ")
2286 << fnType.getNumResults() << " results";
2287
2288 auto retOperandType = retOp.value().getType();
2289 auto fnResultType = fnType.getResult(0);
2290 if (retOperandType != fnResultType)
2291 return retOp.emitOpError(" return value's type (")
2292 << retOperandType << ") mismatch with function's result type ("
2293 << fnResultType << ")";
2294 }
2295 return WalkResult::advance();
2296 });
2297
2298 // TODO: verify other bits like linkage type.
2299
2300 return failure(walkResult.wasInterrupted());
2301 }
2302
build(OpBuilder & builder,OperationState & state,StringRef name,FunctionType type,spirv::FunctionControl control,ArrayRef<NamedAttribute> attrs)2303 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
2304 StringRef name, FunctionType type,
2305 spirv::FunctionControl control,
2306 ArrayRef<NamedAttribute> attrs) {
2307 state.addAttribute(SymbolTable::getSymbolAttrName(),
2308 builder.getStringAttr(name));
2309 state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
2310 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
2311 builder.getI32IntegerAttr(static_cast<uint32_t>(control)));
2312 state.attributes.append(attrs.begin(), attrs.end());
2313 state.addRegion();
2314 }
2315
2316 // CallableOpInterface
getCallableRegion()2317 Region *spirv::FuncOp::getCallableRegion() {
2318 return isExternal() ? nullptr : &body();
2319 }
2320
2321 // CallableOpInterface
getCallableResults()2322 ArrayRef<Type> spirv::FuncOp::getCallableResults() {
2323 return getFunctionType().getResults();
2324 }
2325
2326 //===----------------------------------------------------------------------===//
2327 // spv.FunctionCall
2328 //===----------------------------------------------------------------------===//
2329
verify()2330 LogicalResult spirv::FunctionCallOp::verify() {
2331 auto fnName = calleeAttr();
2332
2333 auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
2334 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
2335 if (!funcOp) {
2336 return emitOpError("callee function '")
2337 << fnName.getValue() << "' not found in nearest symbol table";
2338 }
2339
2340 auto functionType = funcOp.getFunctionType();
2341
2342 if (getNumResults() > 1) {
2343 return emitOpError(
2344 "expected callee function to have 0 or 1 result, but provided ")
2345 << getNumResults();
2346 }
2347
2348 if (functionType.getNumInputs() != getNumOperands()) {
2349 return emitOpError("has incorrect number of operands for callee: expected ")
2350 << functionType.getNumInputs() << ", but provided "
2351 << getNumOperands();
2352 }
2353
2354 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
2355 if (getOperand(i).getType() != functionType.getInput(i)) {
2356 return emitOpError("operand type mismatch: expected operand type ")
2357 << functionType.getInput(i) << ", but provided "
2358 << getOperand(i).getType() << " for operand number " << i;
2359 }
2360 }
2361
2362 if (functionType.getNumResults() != getNumResults()) {
2363 return emitOpError(
2364 "has incorrect number of results has for callee: expected ")
2365 << functionType.getNumResults() << ", but provided "
2366 << getNumResults();
2367 }
2368
2369 if (getNumResults() &&
2370 (getResult(0).getType() != functionType.getResult(0))) {
2371 return emitOpError("result type mismatch: expected ")
2372 << functionType.getResult(0) << ", but provided "
2373 << getResult(0).getType();
2374 }
2375
2376 return success();
2377 }
2378
getCallableForCallee()2379 CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
2380 return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
2381 }
2382
getArgOperands()2383 Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
2384 return arguments();
2385 }
2386
2387 //===----------------------------------------------------------------------===//
2388 // spv.GLFClampOp
2389 //===----------------------------------------------------------------------===//
2390
parse(OpAsmParser & parser,OperationState & result)2391 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
2392 OperationState &result) {
2393 return parseOneResultSameOperandTypeOp(parser, result);
2394 }
print(OpAsmPrinter & p)2395 void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2396
2397 //===----------------------------------------------------------------------===//
2398 // spv.GLUClampOp
2399 //===----------------------------------------------------------------------===//
2400
parse(OpAsmParser & parser,OperationState & result)2401 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
2402 OperationState &result) {
2403 return parseOneResultSameOperandTypeOp(parser, result);
2404 }
print(OpAsmPrinter & p)2405 void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2406
2407 //===----------------------------------------------------------------------===//
2408 // spv.GLSClampOp
2409 //===----------------------------------------------------------------------===//
2410
parse(OpAsmParser & parser,OperationState & result)2411 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
2412 OperationState &result) {
2413 return parseOneResultSameOperandTypeOp(parser, result);
2414 }
print(OpAsmPrinter & p)2415 void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2416
2417 //===----------------------------------------------------------------------===//
2418 // spv.GLFmaOp
2419 //===----------------------------------------------------------------------===//
2420
parse(OpAsmParser & parser,OperationState & result)2421 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
2422 return parseOneResultSameOperandTypeOp(parser, result);
2423 }
print(OpAsmPrinter & p)2424 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2425
2426 //===----------------------------------------------------------------------===//
2427 // spv.GlobalVariable
2428 //===----------------------------------------------------------------------===//
2429
build(OpBuilder & builder,OperationState & state,Type type,StringRef name,unsigned descriptorSet,unsigned binding)2430 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2431 Type type, StringRef name,
2432 unsigned descriptorSet, unsigned binding) {
2433 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2434 state.addAttribute(
2435 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
2436 builder.getI32IntegerAttr(descriptorSet));
2437 state.addAttribute(
2438 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
2439 builder.getI32IntegerAttr(binding));
2440 }
2441
build(OpBuilder & builder,OperationState & state,Type type,StringRef name,spirv::BuiltIn builtin)2442 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2443 Type type, StringRef name,
2444 spirv::BuiltIn builtin) {
2445 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2446 state.addAttribute(
2447 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
2448 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
2449 }
2450
parse(OpAsmParser & parser,OperationState & state)2451 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
2452 OperationState &state) {
2453 // Parse variable name.
2454 StringAttr nameAttr;
2455 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2456 state.attributes)) {
2457 return failure();
2458 }
2459
2460 // Parse optional initializer
2461 if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
2462 FlatSymbolRefAttr initSymbol;
2463 if (parser.parseLParen() ||
2464 parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
2465 state.attributes) ||
2466 parser.parseRParen())
2467 return failure();
2468 }
2469
2470 if (parseVariableDecorations(parser, state)) {
2471 return failure();
2472 }
2473
2474 Type type;
2475 auto loc = parser.getCurrentLocation();
2476 if (parser.parseColonType(type)) {
2477 return failure();
2478 }
2479 if (!type.isa<spirv::PointerType>()) {
2480 return parser.emitError(loc, "expected spv.ptr type");
2481 }
2482 state.addAttribute(kTypeAttrName, TypeAttr::get(type));
2483
2484 return success();
2485 }
2486
print(OpAsmPrinter & printer)2487 void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
2488 SmallVector<StringRef, 4> elidedAttrs{
2489 spirv::attributeName<spirv::StorageClass>()};
2490
2491 // Print variable name.
2492 printer << ' ';
2493 printer.printSymbolName(sym_name());
2494 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
2495
2496 // Print optional initializer
2497 if (auto initializer = this->initializer()) {
2498 printer << " " << kInitializerAttrName << '(';
2499 printer.printSymbolName(*initializer);
2500 printer << ')';
2501 elidedAttrs.push_back(kInitializerAttrName);
2502 }
2503
2504 elidedAttrs.push_back(kTypeAttrName);
2505 printVariableDecorations(*this, printer, elidedAttrs);
2506 printer << " : " << type();
2507 }
2508
verify()2509 LogicalResult spirv::GlobalVariableOp::verify() {
2510 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
2511 // object. It cannot be Generic. It must be the same as the Storage Class
2512 // operand of the Result Type."
2513 // Also, Function storage class is reserved by spv.Variable.
2514 auto storageClass = this->storageClass();
2515 if (storageClass == spirv::StorageClass::Generic ||
2516 storageClass == spirv::StorageClass::Function) {
2517 return emitOpError("storage class cannot be '")
2518 << stringifyStorageClass(storageClass) << "'";
2519 }
2520
2521 if (auto init =
2522 (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
2523 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
2524 (*this)->getParentOp(), init.getAttr());
2525 // TODO: Currently only variable initialization with specialization
2526 // constants and other variables is supported. They could be normal
2527 // constants in the module scope as well.
2528 if (!initOp ||
2529 !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2530 return emitOpError("initializer must be result of a "
2531 "spv.SpecConstant or spv.GlobalVariable op");
2532 }
2533 }
2534
2535 return success();
2536 }
2537
2538 //===----------------------------------------------------------------------===//
2539 // spv.GroupBroadcast
2540 //===----------------------------------------------------------------------===//
2541
verify()2542 LogicalResult spirv::GroupBroadcastOp::verify() {
2543 spirv::Scope scope = execution_scope();
2544 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2545 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2546
2547 if (auto localIdTy = localid().getType().dyn_cast<VectorType>())
2548 if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
2549 return emitOpError("localid is a vector and can be with only "
2550 " 2 or 3 components, actual number is ")
2551 << localIdTy.getNumElements();
2552
2553 return success();
2554 }
2555
2556 //===----------------------------------------------------------------------===//
2557 // spv.GroupNonUniformBallotOp
2558 //===----------------------------------------------------------------------===//
2559
verify()2560 LogicalResult spirv::GroupNonUniformBallotOp::verify() {
2561 spirv::Scope scope = execution_scope();
2562 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2563 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2564
2565 return success();
2566 }
2567
2568 //===----------------------------------------------------------------------===//
2569 // spv.GroupNonUniformBroadcast
2570 //===----------------------------------------------------------------------===//
2571
verify()2572 LogicalResult spirv::GroupNonUniformBroadcastOp::verify() {
2573 spirv::Scope scope = execution_scope();
2574 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2575 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2576
2577 // SPIR-V spec: "Before version 1.5, Id must come from a
2578 // constant instruction.
2579 auto targetEnv = spirv::getDefaultTargetEnv(getContext());
2580 if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
2581 targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2582
2583 if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2584 auto *idOp = id().getDefiningOp();
2585 if (!idOp || !isa<spirv::ConstantOp, // for normal constant
2586 spirv::ReferenceOfOp>(idOp)) // for spec constant
2587 return emitOpError("id must be the result of a constant op");
2588 }
2589
2590 return success();
2591 }
2592
2593 //===----------------------------------------------------------------------===//
2594 // spv.SubgroupBlockReadINTEL
2595 //===----------------------------------------------------------------------===//
2596
parse(OpAsmParser & parser,OperationState & state)2597 ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
2598 OperationState &state) {
2599 // Parse the storage class specification
2600 spirv::StorageClass storageClass;
2601 OpAsmParser::UnresolvedOperand ptrInfo;
2602 Type elementType;
2603 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2604 parser.parseColon() || parser.parseType(elementType)) {
2605 return failure();
2606 }
2607
2608 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2609 if (auto valVecTy = elementType.dyn_cast<VectorType>())
2610 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2611
2612 if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2613 return failure();
2614 }
2615
2616 state.addTypes(elementType);
2617 return success();
2618 }
2619
print(OpAsmPrinter & printer)2620 void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) {
2621 printer << " " << ptr() << " : " << getType();
2622 }
2623
verify()2624 LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
2625 if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
2626 return failure();
2627
2628 return success();
2629 }
2630
2631 //===----------------------------------------------------------------------===//
2632 // spv.SubgroupBlockWriteINTEL
2633 //===----------------------------------------------------------------------===//
2634
parse(OpAsmParser & parser,OperationState & state)2635 ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
2636 OperationState &state) {
2637 // Parse the storage class specification
2638 spirv::StorageClass storageClass;
2639 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
2640 auto loc = parser.getCurrentLocation();
2641 Type elementType;
2642 if (parseEnumStrAttr(storageClass, parser) ||
2643 parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2644 parser.parseType(elementType)) {
2645 return failure();
2646 }
2647
2648 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2649 if (auto valVecTy = elementType.dyn_cast<VectorType>())
2650 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2651
2652 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2653 state.operands)) {
2654 return failure();
2655 }
2656 return success();
2657 }
2658
print(OpAsmPrinter & printer)2659 void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) {
2660 printer << " " << ptr() << ", " << value() << " : " << value().getType();
2661 }
2662
verify()2663 LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() {
2664 if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
2665 return failure();
2666
2667 return success();
2668 }
2669
2670 //===----------------------------------------------------------------------===//
2671 // spv.GroupNonUniformElectOp
2672 //===----------------------------------------------------------------------===//
2673
verify()2674 LogicalResult spirv::GroupNonUniformElectOp::verify() {
2675 spirv::Scope scope = execution_scope();
2676 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2677 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2678
2679 return success();
2680 }
2681
2682 //===----------------------------------------------------------------------===//
2683 // spv.GroupNonUniformFAddOp
2684 //===----------------------------------------------------------------------===//
2685
verify()2686 LogicalResult spirv::GroupNonUniformFAddOp::verify() {
2687 return verifyGroupNonUniformArithmeticOp(*this);
2688 }
2689
parse(OpAsmParser & parser,OperationState & result)2690 ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser,
2691 OperationState &result) {
2692 return parseGroupNonUniformArithmeticOp(parser, result);
2693 }
print(OpAsmPrinter & p)2694 void spirv::GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
2695 printGroupNonUniformArithmeticOp(*this, p);
2696 }
2697
2698 //===----------------------------------------------------------------------===//
2699 // spv.GroupNonUniformFMaxOp
2700 //===----------------------------------------------------------------------===//
2701
verify()2702 LogicalResult spirv::GroupNonUniformFMaxOp::verify() {
2703 return verifyGroupNonUniformArithmeticOp(*this);
2704 }
2705
parse(OpAsmParser & parser,OperationState & result)2706 ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
2707 OperationState &result) {
2708 return parseGroupNonUniformArithmeticOp(parser, result);
2709 }
print(OpAsmPrinter & p)2710 void spirv::GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
2711 printGroupNonUniformArithmeticOp(*this, p);
2712 }
2713
2714 //===----------------------------------------------------------------------===//
2715 // spv.GroupNonUniformFMinOp
2716 //===----------------------------------------------------------------------===//
2717
verify()2718 LogicalResult spirv::GroupNonUniformFMinOp::verify() {
2719 return verifyGroupNonUniformArithmeticOp(*this);
2720 }
2721
parse(OpAsmParser & parser,OperationState & result)2722 ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser,
2723 OperationState &result) {
2724 return parseGroupNonUniformArithmeticOp(parser, result);
2725 }
print(OpAsmPrinter & p)2726 void spirv::GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
2727 printGroupNonUniformArithmeticOp(*this, p);
2728 }
2729
2730 //===----------------------------------------------------------------------===//
2731 // spv.GroupNonUniformFMulOp
2732 //===----------------------------------------------------------------------===//
2733
verify()2734 LogicalResult spirv::GroupNonUniformFMulOp::verify() {
2735 return verifyGroupNonUniformArithmeticOp(*this);
2736 }
2737
parse(OpAsmParser & parser,OperationState & result)2738 ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser,
2739 OperationState &result) {
2740 return parseGroupNonUniformArithmeticOp(parser, result);
2741 }
print(OpAsmPrinter & p)2742 void spirv::GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
2743 printGroupNonUniformArithmeticOp(*this, p);
2744 }
2745
2746 //===----------------------------------------------------------------------===//
2747 // spv.GroupNonUniformIAddOp
2748 //===----------------------------------------------------------------------===//
2749
verify()2750 LogicalResult spirv::GroupNonUniformIAddOp::verify() {
2751 return verifyGroupNonUniformArithmeticOp(*this);
2752 }
2753
parse(OpAsmParser & parser,OperationState & result)2754 ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser,
2755 OperationState &result) {
2756 return parseGroupNonUniformArithmeticOp(parser, result);
2757 }
print(OpAsmPrinter & p)2758 void spirv::GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
2759 printGroupNonUniformArithmeticOp(*this, p);
2760 }
2761
2762 //===----------------------------------------------------------------------===//
2763 // spv.GroupNonUniformIMulOp
2764 //===----------------------------------------------------------------------===//
2765
verify()2766 LogicalResult spirv::GroupNonUniformIMulOp::verify() {
2767 return verifyGroupNonUniformArithmeticOp(*this);
2768 }
2769
parse(OpAsmParser & parser,OperationState & result)2770 ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser,
2771 OperationState &result) {
2772 return parseGroupNonUniformArithmeticOp(parser, result);
2773 }
print(OpAsmPrinter & p)2774 void spirv::GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
2775 printGroupNonUniformArithmeticOp(*this, p);
2776 }
2777
2778 //===----------------------------------------------------------------------===//
2779 // spv.GroupNonUniformSMaxOp
2780 //===----------------------------------------------------------------------===//
2781
verify()2782 LogicalResult spirv::GroupNonUniformSMaxOp::verify() {
2783 return verifyGroupNonUniformArithmeticOp(*this);
2784 }
2785
parse(OpAsmParser & parser,OperationState & result)2786 ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
2787 OperationState &result) {
2788 return parseGroupNonUniformArithmeticOp(parser, result);
2789 }
print(OpAsmPrinter & p)2790 void spirv::GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
2791 printGroupNonUniformArithmeticOp(*this, p);
2792 }
2793
2794 //===----------------------------------------------------------------------===//
2795 // spv.GroupNonUniformSMinOp
2796 //===----------------------------------------------------------------------===//
2797
verify()2798 LogicalResult spirv::GroupNonUniformSMinOp::verify() {
2799 return verifyGroupNonUniformArithmeticOp(*this);
2800 }
2801
parse(OpAsmParser & parser,OperationState & result)2802 ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser,
2803 OperationState &result) {
2804 return parseGroupNonUniformArithmeticOp(parser, result);
2805 }
print(OpAsmPrinter & p)2806 void spirv::GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
2807 printGroupNonUniformArithmeticOp(*this, p);
2808 }
2809
2810 //===----------------------------------------------------------------------===//
2811 // spv.GroupNonUniformUMaxOp
2812 //===----------------------------------------------------------------------===//
2813
verify()2814 LogicalResult spirv::GroupNonUniformUMaxOp::verify() {
2815 return verifyGroupNonUniformArithmeticOp(*this);
2816 }
2817
parse(OpAsmParser & parser,OperationState & result)2818 ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
2819 OperationState &result) {
2820 return parseGroupNonUniformArithmeticOp(parser, result);
2821 }
print(OpAsmPrinter & p)2822 void spirv::GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
2823 printGroupNonUniformArithmeticOp(*this, p);
2824 }
2825
2826 //===----------------------------------------------------------------------===//
2827 // spv.GroupNonUniformUMinOp
2828 //===----------------------------------------------------------------------===//
2829
verify()2830 LogicalResult spirv::GroupNonUniformUMinOp::verify() {
2831 return verifyGroupNonUniformArithmeticOp(*this);
2832 }
2833
parse(OpAsmParser & parser,OperationState & result)2834 ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser,
2835 OperationState &result) {
2836 return parseGroupNonUniformArithmeticOp(parser, result);
2837 }
print(OpAsmPrinter & p)2838 void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
2839 printGroupNonUniformArithmeticOp(*this, p);
2840 }
2841
2842 //===----------------------------------------------------------------------===//
2843 // spv.ISubBorrowOp
2844 //===----------------------------------------------------------------------===//
2845
verify()2846 LogicalResult spirv::ISubBorrowOp::verify() {
2847 auto resultType = getType().cast<spirv::StructType>();
2848 if (resultType.getNumElements() != 2)
2849 return emitOpError("expected result struct type containing two members");
2850
2851 SmallVector<Type, 4> types;
2852 types.push_back(operand1().getType());
2853 types.push_back(operand2().getType());
2854 types.push_back(resultType.getElementType(0));
2855 types.push_back(resultType.getElementType(1));
2856 if (!llvm::is_splat(types))
2857 return emitOpError(
2858 "expected all operand types and struct member types are the same");
2859
2860 return success();
2861 }
2862
parse(OpAsmParser & parser,OperationState & state)2863 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
2864 OperationState &state) {
2865 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
2866 if (parser.parseOptionalAttrDict(state.attributes) ||
2867 parser.parseOperandList(operands) || parser.parseColon())
2868 return failure();
2869
2870 Type resultType;
2871 auto loc = parser.getCurrentLocation();
2872 if (parser.parseType(resultType))
2873 return failure();
2874
2875 auto structType = resultType.dyn_cast<spirv::StructType>();
2876 if (!structType || structType.getNumElements() != 2)
2877 return parser.emitError(loc, "expected spv.struct type with two members");
2878
2879 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
2880 if (parser.resolveOperands(operands, operandTypes, loc, state.operands))
2881 return failure();
2882
2883 state.addTypes(resultType);
2884 return success();
2885 }
2886
print(OpAsmPrinter & printer)2887 void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
2888 printer << ' ';
2889 printer.printOptionalAttrDict((*this)->getAttrs());
2890 printer.printOperands((*this)->getOperands());
2891 printer << " : " << getType();
2892 }
2893
2894 //===----------------------------------------------------------------------===//
2895 // spv.LoadOp
2896 //===----------------------------------------------------------------------===//
2897
build(OpBuilder & builder,OperationState & state,Value basePtr,MemoryAccessAttr memoryAccess,IntegerAttr alignment)2898 void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
2899 Value basePtr, MemoryAccessAttr memoryAccess,
2900 IntegerAttr alignment) {
2901 auto ptrType = basePtr.getType().cast<spirv::PointerType>();
2902 build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
2903 alignment);
2904 }
2905
parse(OpAsmParser & parser,OperationState & state)2906 ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &state) {
2907 // Parse the storage class specification
2908 spirv::StorageClass storageClass;
2909 OpAsmParser::UnresolvedOperand ptrInfo;
2910 Type elementType;
2911 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2912 parseMemoryAccessAttributes(parser, state) ||
2913 parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
2914 parser.parseType(elementType)) {
2915 return failure();
2916 }
2917
2918 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2919 if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2920 return failure();
2921 }
2922
2923 state.addTypes(elementType);
2924 return success();
2925 }
2926
print(OpAsmPrinter & printer)2927 void spirv::LoadOp::print(OpAsmPrinter &printer) {
2928 SmallVector<StringRef, 4> elidedAttrs;
2929 StringRef sc = stringifyStorageClass(
2930 ptr().getType().cast<spirv::PointerType>().getStorageClass());
2931 printer << " \"" << sc << "\" " << ptr();
2932
2933 printMemoryAccessAttribute(*this, printer, elidedAttrs);
2934
2935 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
2936 printer << " : " << getType();
2937 }
2938
verify()2939 LogicalResult spirv::LoadOp::verify() {
2940 // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
2941 // type with fixed size; i.e., it cannot be, nor include, any
2942 // OpTypeRuntimeArray types."
2943 if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value()))) {
2944 return failure();
2945 }
2946 return verifyMemoryAccessAttribute(*this);
2947 }
2948
2949 //===----------------------------------------------------------------------===//
2950 // spv.mlir.loop
2951 //===----------------------------------------------------------------------===//
2952
build(OpBuilder & builder,OperationState & state)2953 void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
2954 state.addAttribute("loop_control",
2955 builder.getI32IntegerAttr(
2956 static_cast<uint32_t>(spirv::LoopControl::None)));
2957 state.addRegion();
2958 }
2959
parse(OpAsmParser & parser,OperationState & state)2960 ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &state) {
2961 if (parseControlAttribute<spirv::LoopControl>(parser, state))
2962 return failure();
2963 return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
2964 /*argTypes=*/{});
2965 }
2966
print(OpAsmPrinter & printer)2967 void spirv::LoopOp::print(OpAsmPrinter &printer) {
2968 auto control = loop_control();
2969 if (control != spirv::LoopControl::None)
2970 printer << " control(" << spirv::stringifyLoopControl(control) << ")";
2971 printer << ' ';
2972 printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2973 /*printBlockTerminators=*/true);
2974 }
2975
2976 /// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
2977 /// given `dstBlock`.
hasOneBranchOpTo(Block & srcBlock,Block & dstBlock)2978 static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
2979 // Check that there is only one op in the `srcBlock`.
2980 if (!llvm::hasSingleElement(srcBlock))
2981 return false;
2982
2983 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
2984 return branchOp && branchOp.getSuccessor() == &dstBlock;
2985 }
2986
verifyRegions()2987 LogicalResult spirv::LoopOp::verifyRegions() {
2988 auto *op = getOperation();
2989
2990 // We need to verify that the blocks follow the following layout:
2991 //
2992 // +-------------+
2993 // | entry block |
2994 // +-------------+
2995 // |
2996 // v
2997 // +-------------+
2998 // | loop header | <-----+
2999 // +-------------+ |
3000 // |
3001 // ... |
3002 // \ | / |
3003 // v |
3004 // +---------------+ |
3005 // | loop continue | -----+
3006 // +---------------+
3007 //
3008 // ...
3009 // \ | /
3010 // v
3011 // +-------------+
3012 // | merge block |
3013 // +-------------+
3014
3015 auto ®ion = op->getRegion(0);
3016 // Allow empty region as a degenerated case, which can come from
3017 // optimizations.
3018 if (region.empty())
3019 return success();
3020
3021 // The last block is the merge block.
3022 Block &merge = region.back();
3023 if (!isMergeBlock(merge))
3024 return emitOpError(
3025 "last block must be the merge block with only one 'spv.mlir.merge' op");
3026
3027 if (std::next(region.begin()) == region.end())
3028 return emitOpError(
3029 "must have an entry block branching to the loop header block");
3030 // The first block is the entry block.
3031 Block &entry = region.front();
3032
3033 if (std::next(region.begin(), 2) == region.end())
3034 return emitOpError(
3035 "must have a loop header block branched from the entry block");
3036 // The second block is the loop header block.
3037 Block &header = *std::next(region.begin(), 1);
3038
3039 if (!hasOneBranchOpTo(entry, header))
3040 return emitOpError(
3041 "entry block must only have one 'spv.Branch' op to the second block");
3042
3043 if (std::next(region.begin(), 3) == region.end())
3044 return emitOpError(
3045 "requires a loop continue block branching to the loop header block");
3046 // The second to last block is the loop continue block.
3047 Block &cont = *std::prev(region.end(), 2);
3048
3049 // Make sure that we have a branch from the loop continue block to the loop
3050 // header block.
3051 if (llvm::none_of(
3052 llvm::seq<unsigned>(0, cont.getNumSuccessors()),
3053 [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
3054 return emitOpError("second to last block must be the loop continue "
3055 "block that branches to the loop header block");
3056
3057 // Make sure that no other blocks (except the entry and loop continue block)
3058 // branches to the loop header block.
3059 for (auto &block : llvm::make_range(std::next(region.begin(), 2),
3060 std::prev(region.end(), 2))) {
3061 for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
3062 if (block.getSuccessor(i) == &header) {
3063 return emitOpError("can only have the entry and loop continue "
3064 "block branching to the loop header block");
3065 }
3066 }
3067 }
3068
3069 return success();
3070 }
3071
getEntryBlock()3072 Block *spirv::LoopOp::getEntryBlock() {
3073 assert(!body().empty() && "op region should not be empty!");
3074 return &body().front();
3075 }
3076
getHeaderBlock()3077 Block *spirv::LoopOp::getHeaderBlock() {
3078 assert(!body().empty() && "op region should not be empty!");
3079 // The second block is the loop header block.
3080 return &*std::next(body().begin());
3081 }
3082
getContinueBlock()3083 Block *spirv::LoopOp::getContinueBlock() {
3084 assert(!body().empty() && "op region should not be empty!");
3085 // The second to last block is the loop continue block.
3086 return &*std::prev(body().end(), 2);
3087 }
3088
getMergeBlock()3089 Block *spirv::LoopOp::getMergeBlock() {
3090 assert(!body().empty() && "op region should not be empty!");
3091 // The last block is the loop merge block.
3092 return &body().back();
3093 }
3094
addEntryAndMergeBlock()3095 void spirv::LoopOp::addEntryAndMergeBlock() {
3096 assert(body().empty() && "entry and merge block already exist");
3097 body().push_back(new Block());
3098 auto *mergeBlock = new Block();
3099 body().push_back(mergeBlock);
3100 OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3101
3102 // Add a spv.mlir.merge op into the merge block.
3103 builder.create<spirv::MergeOp>(getLoc());
3104 }
3105
3106 //===----------------------------------------------------------------------===//
3107 // spv.MemoryBarrierOp
3108 //===----------------------------------------------------------------------===//
3109
verify()3110 LogicalResult spirv::MemoryBarrierOp::verify() {
3111 return verifyMemorySemantics(getOperation(), memory_semantics());
3112 }
3113
3114 //===----------------------------------------------------------------------===//
3115 // spv.mlir.merge
3116 //===----------------------------------------------------------------------===//
3117
verify()3118 LogicalResult spirv::MergeOp::verify() {
3119 auto *parentOp = (*this)->getParentOp();
3120 if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
3121 return emitOpError(
3122 "expected parent op to be 'spv.mlir.selection' or 'spv.mlir.loop'");
3123
3124 // TODO: This check should be done in `verifyRegions` of parent op.
3125 Block &parentLastBlock = (*this)->getParentRegion()->back();
3126 if (getOperation() != parentLastBlock.getTerminator())
3127 return emitOpError("can only be used in the last block of "
3128 "'spv.mlir.selection' or 'spv.mlir.loop'");
3129 return success();
3130 }
3131
3132 //===----------------------------------------------------------------------===//
3133 // spv.module
3134 //===----------------------------------------------------------------------===//
3135
build(OpBuilder & builder,OperationState & state,Optional<StringRef> name)3136 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3137 Optional<StringRef> name) {
3138 OpBuilder::InsertionGuard guard(builder);
3139 builder.createBlock(state.addRegion());
3140 if (name) {
3141 state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
3142 builder.getStringAttr(*name));
3143 }
3144 }
3145
build(OpBuilder & builder,OperationState & state,spirv::AddressingModel addressingModel,spirv::MemoryModel memoryModel,Optional<VerCapExtAttr> vceTriple,Optional<StringRef> name)3146 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3147 spirv::AddressingModel addressingModel,
3148 spirv::MemoryModel memoryModel,
3149 Optional<VerCapExtAttr> vceTriple,
3150 Optional<StringRef> name) {
3151 state.addAttribute(
3152 "addressing_model",
3153 builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
3154 state.addAttribute("memory_model", builder.getI32IntegerAttr(
3155 static_cast<int32_t>(memoryModel)));
3156 OpBuilder::InsertionGuard guard(builder);
3157 builder.createBlock(state.addRegion());
3158 if (vceTriple)
3159 state.addAttribute(getVCETripleAttrName(), *vceTriple);
3160 if (name)
3161 state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
3162 builder.getStringAttr(*name));
3163 }
3164
parse(OpAsmParser & parser,OperationState & state)3165 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser, OperationState &state) {
3166 Region *body = state.addRegion();
3167
3168 // If the name is present, parse it.
3169 StringAttr nameAttr;
3170 (void)parser.parseOptionalSymbolName(
3171 nameAttr, mlir::SymbolTable::getSymbolAttrName(), state.attributes);
3172
3173 // Parse attributes
3174 spirv::AddressingModel addrModel;
3175 spirv::MemoryModel memoryModel;
3176 if (::parseEnumKeywordAttr(addrModel, parser, state) ||
3177 ::parseEnumKeywordAttr(memoryModel, parser, state))
3178 return failure();
3179
3180 if (succeeded(parser.parseOptionalKeyword("requires"))) {
3181 spirv::VerCapExtAttr vceTriple;
3182 if (parser.parseAttribute(vceTriple,
3183 spirv::ModuleOp::getVCETripleAttrName(),
3184 state.attributes))
3185 return failure();
3186 }
3187
3188 if (parser.parseOptionalAttrDictWithKeyword(state.attributes) ||
3189 parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
3190 return failure();
3191
3192 // Make sure we have at least one block.
3193 if (body->empty())
3194 body->push_back(new Block());
3195
3196 return success();
3197 }
3198
print(OpAsmPrinter & printer)3199 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
3200 if (Optional<StringRef> name = getName()) {
3201 printer << ' ';
3202 printer.printSymbolName(*name);
3203 }
3204
3205 SmallVector<StringRef, 2> elidedAttrs;
3206
3207 printer << " " << spirv::stringifyAddressingModel(addressing_model()) << " "
3208 << spirv::stringifyMemoryModel(memory_model());
3209 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
3210 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
3211 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
3212 mlir::SymbolTable::getSymbolAttrName()});
3213
3214 if (Optional<spirv::VerCapExtAttr> triple = vce_triple()) {
3215 printer << " requires " << *triple;
3216 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
3217 }
3218
3219 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
3220 printer << ' ';
3221 printer.printRegion(getRegion());
3222 }
3223
verifyRegions()3224 LogicalResult spirv::ModuleOp::verifyRegions() {
3225 Dialect *dialect = (*this)->getDialect();
3226 DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
3227 entryPoints;
3228 mlir::SymbolTable table(*this);
3229
3230 for (auto &op : *getBody()) {
3231 if (op.getDialect() != dialect)
3232 return op.emitError("'spv.module' can only contain spv.* ops");
3233
3234 // For EntryPoint op, check that the function and execution model is not
3235 // duplicated in EntryPointOps. Also verify that the interface specified
3236 // comes from globalVariables here to make this check cheaper.
3237 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
3238 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.fn());
3239 if (!funcOp) {
3240 return entryPointOp.emitError("function '")
3241 << entryPointOp.fn() << "' not found in 'spv.module'";
3242 }
3243 if (auto interface = entryPointOp.interface()) {
3244 for (Attribute varRef : interface) {
3245 auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
3246 if (!varSymRef) {
3247 return entryPointOp.emitError(
3248 "expected symbol reference for interface "
3249 "specification instead of '")
3250 << varRef;
3251 }
3252 auto variableOp =
3253 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
3254 if (!variableOp) {
3255 return entryPointOp.emitError("expected spv.GlobalVariable "
3256 "symbol reference instead of'")
3257 << varSymRef << "'";
3258 }
3259 }
3260 }
3261
3262 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
3263 funcOp, entryPointOp.execution_model());
3264 auto entryPtIt = entryPoints.find(key);
3265 if (entryPtIt != entryPoints.end()) {
3266 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
3267 }
3268 entryPoints[key] = entryPointOp;
3269 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
3270 if (funcOp.isExternal())
3271 return op.emitError("'spv.module' cannot contain external functions");
3272
3273 // TODO: move this check to spv.func.
3274 for (auto &block : funcOp)
3275 for (auto &op : block) {
3276 if (op.getDialect() != dialect)
3277 return op.emitError(
3278 "functions in 'spv.module' can only contain spv.* ops");
3279 }
3280 }
3281 }
3282
3283 return success();
3284 }
3285
3286 //===----------------------------------------------------------------------===//
3287 // spv.mlir.referenceof
3288 //===----------------------------------------------------------------------===//
3289
verify()3290 LogicalResult spirv::ReferenceOfOp::verify() {
3291 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
3292 (*this)->getParentOp(), spec_constAttr());
3293 Type constType;
3294
3295 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
3296 if (specConstOp)
3297 constType = specConstOp.default_value().getType();
3298
3299 auto specConstCompositeOp =
3300 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
3301 if (specConstCompositeOp)
3302 constType = specConstCompositeOp.type();
3303
3304 if (!specConstOp && !specConstCompositeOp)
3305 return emitOpError(
3306 "expected spv.SpecConstant or spv.SpecConstantComposite symbol");
3307
3308 if (reference().getType() != constType)
3309 return emitOpError("result type mismatch with the referenced "
3310 "specialization constant's type");
3311
3312 return success();
3313 }
3314
3315 //===----------------------------------------------------------------------===//
3316 // spv.Return
3317 //===----------------------------------------------------------------------===//
3318
verify()3319 LogicalResult spirv::ReturnOp::verify() {
3320 // Verification is performed in spv.func op.
3321 return success();
3322 }
3323
3324 //===----------------------------------------------------------------------===//
3325 // spv.ReturnValue
3326 //===----------------------------------------------------------------------===//
3327
verify()3328 LogicalResult spirv::ReturnValueOp::verify() {
3329 // Verification is performed in spv.func op.
3330 return success();
3331 }
3332
3333 //===----------------------------------------------------------------------===//
3334 // spv.Select
3335 //===----------------------------------------------------------------------===//
3336
verify()3337 LogicalResult spirv::SelectOp::verify() {
3338 if (auto conditionTy = condition().getType().dyn_cast<VectorType>()) {
3339 auto resultVectorTy = result().getType().dyn_cast<VectorType>();
3340 if (!resultVectorTy) {
3341 return emitOpError("result expected to be of vector type when "
3342 "condition is of vector type");
3343 }
3344 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
3345 return emitOpError("result should have the same number of elements as "
3346 "the condition when condition is of vector type");
3347 }
3348 }
3349 return success();
3350 }
3351
3352 //===----------------------------------------------------------------------===//
3353 // spv.mlir.selection
3354 //===----------------------------------------------------------------------===//
3355
parse(OpAsmParser & parser,OperationState & state)3356 ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
3357 OperationState &state) {
3358 if (parseControlAttribute<spirv::SelectionControl>(parser, state))
3359 return failure();
3360 return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
3361 /*argTypes=*/{});
3362 }
3363
print(OpAsmPrinter & printer)3364 void spirv::SelectionOp::print(OpAsmPrinter &printer) {
3365 auto control = selection_control();
3366 if (control != spirv::SelectionControl::None)
3367 printer << " control(" << spirv::stringifySelectionControl(control) << ")";
3368 printer << ' ';
3369 printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3370 /*printBlockTerminators=*/true);
3371 }
3372
verifyRegions()3373 LogicalResult spirv::SelectionOp::verifyRegions() {
3374 auto *op = getOperation();
3375
3376 // We need to verify that the blocks follow the following layout:
3377 //
3378 // +--------------+
3379 // | header block |
3380 // +--------------+
3381 // / | \
3382 // ...
3383 //
3384 //
3385 // +---------+ +---------+ +---------+
3386 // | case #0 | | case #1 | | case #2 | ...
3387 // +---------+ +---------+ +---------+
3388 //
3389 //
3390 // ...
3391 // \ | /
3392 // v
3393 // +-------------+
3394 // | merge block |
3395 // +-------------+
3396
3397 auto ®ion = op->getRegion(0);
3398 // Allow empty region as a degenerated case, which can come from
3399 // optimizations.
3400 if (region.empty())
3401 return success();
3402
3403 // The last block is the merge block.
3404 if (!isMergeBlock(region.back()))
3405 return emitOpError(
3406 "last block must be the merge block with only one 'spv.mlir.merge' op");
3407
3408 if (std::next(region.begin()) == region.end())
3409 return emitOpError("must have a selection header block");
3410
3411 return success();
3412 }
3413
getHeaderBlock()3414 Block *spirv::SelectionOp::getHeaderBlock() {
3415 assert(!body().empty() && "op region should not be empty!");
3416 // The first block is the loop header block.
3417 return &body().front();
3418 }
3419
getMergeBlock()3420 Block *spirv::SelectionOp::getMergeBlock() {
3421 assert(!body().empty() && "op region should not be empty!");
3422 // The last block is the loop merge block.
3423 return &body().back();
3424 }
3425
addMergeBlock()3426 void spirv::SelectionOp::addMergeBlock() {
3427 assert(body().empty() && "entry and merge block already exist");
3428 auto *mergeBlock = new Block();
3429 body().push_back(mergeBlock);
3430 OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3431
3432 // Add a spv.mlir.merge op into the merge block.
3433 builder.create<spirv::MergeOp>(getLoc());
3434 }
3435
createIfThen(Location loc,Value condition,function_ref<void (OpBuilder & builder)> thenBody,OpBuilder & builder)3436 spirv::SelectionOp spirv::SelectionOp::createIfThen(
3437 Location loc, Value condition,
3438 function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
3439 auto selectionOp =
3440 builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
3441
3442 selectionOp.addMergeBlock();
3443 Block *mergeBlock = selectionOp.getMergeBlock();
3444 Block *thenBlock = nullptr;
3445
3446 // Build the "then" block.
3447 {
3448 OpBuilder::InsertionGuard guard(builder);
3449 thenBlock = builder.createBlock(mergeBlock);
3450 thenBody(builder);
3451 builder.create<spirv::BranchOp>(loc, mergeBlock);
3452 }
3453
3454 // Build the header block.
3455 {
3456 OpBuilder::InsertionGuard guard(builder);
3457 builder.createBlock(thenBlock);
3458 builder.create<spirv::BranchConditionalOp>(
3459 loc, condition, thenBlock,
3460 /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
3461 /*falseArguments=*/ArrayRef<Value>());
3462 }
3463
3464 return selectionOp;
3465 }
3466
3467 //===----------------------------------------------------------------------===//
3468 // spv.SpecConstant
3469 //===----------------------------------------------------------------------===//
3470
parse(OpAsmParser & parser,OperationState & state)3471 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
3472 OperationState &state) {
3473 StringAttr nameAttr;
3474 Attribute valueAttr;
3475
3476 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
3477 state.attributes))
3478 return failure();
3479
3480 // Parse optional spec_id.
3481 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
3482 IntegerAttr specIdAttr;
3483 if (parser.parseLParen() ||
3484 parser.parseAttribute(specIdAttr, kSpecIdAttrName, state.attributes) ||
3485 parser.parseRParen())
3486 return failure();
3487 }
3488
3489 if (parser.parseEqual() ||
3490 parser.parseAttribute(valueAttr, kDefaultValueAttrName, state.attributes))
3491 return failure();
3492
3493 return success();
3494 }
3495
print(OpAsmPrinter & printer)3496 void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
3497 printer << ' ';
3498 printer.printSymbolName(sym_name());
3499 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3500 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
3501 printer << " = " << default_value();
3502 }
3503
verify()3504 LogicalResult spirv::SpecConstantOp::verify() {
3505 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3506 if (specID.getValue().isNegative())
3507 return emitOpError("SpecId cannot be negative");
3508
3509 auto value = default_value();
3510 if (value.isa<IntegerAttr, FloatAttr>()) {
3511 // Make sure bitwidth is allowed.
3512 if (!value.getType().isa<spirv::SPIRVType>())
3513 return emitOpError("default value bitwidth disallowed");
3514 return success();
3515 }
3516 return emitOpError(
3517 "default value can only be a bool, integer, or float scalar");
3518 }
3519
3520 //===----------------------------------------------------------------------===//
3521 // spv.StoreOp
3522 //===----------------------------------------------------------------------===//
3523
parse(OpAsmParser & parser,OperationState & state)3524 ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &state) {
3525 // Parse the storage class specification
3526 spirv::StorageClass storageClass;
3527 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
3528 auto loc = parser.getCurrentLocation();
3529 Type elementType;
3530 if (parseEnumStrAttr(storageClass, parser) ||
3531 parser.parseOperandList(operandInfo, 2) ||
3532 parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3533 parser.parseType(elementType)) {
3534 return failure();
3535 }
3536
3537 auto ptrType = spirv::PointerType::get(elementType, storageClass);
3538 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
3539 state.operands)) {
3540 return failure();
3541 }
3542 return success();
3543 }
3544
print(OpAsmPrinter & printer)3545 void spirv::StoreOp::print(OpAsmPrinter &printer) {
3546 SmallVector<StringRef, 4> elidedAttrs;
3547 StringRef sc = stringifyStorageClass(
3548 ptr().getType().cast<spirv::PointerType>().getStorageClass());
3549 printer << " \"" << sc << "\" " << ptr() << ", " << value();
3550
3551 printMemoryAccessAttribute(*this, printer, elidedAttrs);
3552
3553 printer << " : " << value().getType();
3554 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3555 }
3556
verify()3557 LogicalResult spirv::StoreOp::verify() {
3558 // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
3559 // OpTypePointer whose Type operand is the same as the type of Object."
3560 if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value())))
3561 return failure();
3562 return verifyMemoryAccessAttribute(*this);
3563 }
3564
3565 //===----------------------------------------------------------------------===//
3566 // spv.Unreachable
3567 //===----------------------------------------------------------------------===//
3568
verify()3569 LogicalResult spirv::UnreachableOp::verify() {
3570 auto *block = (*this)->getBlock();
3571 // Fast track: if this is in entry block, its invalid. Otherwise, if no
3572 // predecessors, it's valid.
3573 if (block->isEntryBlock())
3574 return emitOpError("cannot be used in reachable block");
3575 if (block->hasNoPredecessors())
3576 return success();
3577
3578 // TODO: further verification needs to analyze reachability from
3579 // the entry block.
3580
3581 return success();
3582 }
3583
3584 //===----------------------------------------------------------------------===//
3585 // spv.Variable
3586 //===----------------------------------------------------------------------===//
3587
parse(OpAsmParser & parser,OperationState & state)3588 ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
3589 OperationState &state) {
3590 // Parse optional initializer
3591 Optional<OpAsmParser::UnresolvedOperand> initInfo;
3592 if (succeeded(parser.parseOptionalKeyword("init"))) {
3593 initInfo = OpAsmParser::UnresolvedOperand();
3594 if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
3595 parser.parseRParen())
3596 return failure();
3597 }
3598
3599 if (parseVariableDecorations(parser, state)) {
3600 return failure();
3601 }
3602
3603 // Parse result pointer type
3604 Type type;
3605 if (parser.parseColon())
3606 return failure();
3607 auto loc = parser.getCurrentLocation();
3608 if (parser.parseType(type))
3609 return failure();
3610
3611 auto ptrType = type.dyn_cast<spirv::PointerType>();
3612 if (!ptrType)
3613 return parser.emitError(loc, "expected spv.ptr type");
3614 state.addTypes(ptrType);
3615
3616 // Resolve the initializer operand
3617 if (initInfo) {
3618 if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
3619 state.operands))
3620 return failure();
3621 }
3622
3623 auto attr = parser.getBuilder().getI32IntegerAttr(
3624 llvm::bit_cast<int32_t>(ptrType.getStorageClass()));
3625 state.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
3626
3627 return success();
3628 }
3629
print(OpAsmPrinter & printer)3630 void spirv::VariableOp::print(OpAsmPrinter &printer) {
3631 SmallVector<StringRef, 4> elidedAttrs{
3632 spirv::attributeName<spirv::StorageClass>()};
3633 // Print optional initializer
3634 if (getNumOperands() != 0)
3635 printer << " init(" << initializer() << ")";
3636
3637 printVariableDecorations(*this, printer, elidedAttrs);
3638 printer << " : " << getType();
3639 }
3640
verify()3641 LogicalResult spirv::VariableOp::verify() {
3642 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
3643 // object. It cannot be Generic. It must be the same as the Storage Class
3644 // operand of the Result Type."
3645 if (storage_class() != spirv::StorageClass::Function) {
3646 return emitOpError(
3647 "can only be used to model function-level variables. Use "
3648 "spv.GlobalVariable for module-level variables.");
3649 }
3650
3651 auto pointerType = pointer().getType().cast<spirv::PointerType>();
3652 if (storage_class() != pointerType.getStorageClass())
3653 return emitOpError(
3654 "storage class must match result pointer's storage class");
3655
3656 if (getNumOperands() != 0) {
3657 // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
3658 // a global (module scope) OpVariable instruction".
3659 auto *initOp = getOperand(0).getDefiningOp();
3660 if (!initOp || !isa<spirv::ConstantOp, // for normal constant
3661 spirv::ReferenceOfOp, // for spec constant
3662 spirv::AddressOfOp>(initOp))
3663 return emitOpError("initializer must be the result of a "
3664 "constant or spv.GlobalVariable op");
3665 }
3666
3667 // TODO: generate these strings using ODS.
3668 auto *op = getOperation();
3669 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
3670 stringifyDecoration(spirv::Decoration::DescriptorSet));
3671 auto bindingName = llvm::convertToSnakeFromCamelCase(
3672 stringifyDecoration(spirv::Decoration::Binding));
3673 auto builtInName = llvm::convertToSnakeFromCamelCase(
3674 stringifyDecoration(spirv::Decoration::BuiltIn));
3675
3676 for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
3677 if (op->getAttr(attr))
3678 return emitOpError("cannot have '")
3679 << attr << "' attribute (only allowed in spv.GlobalVariable)";
3680 }
3681
3682 return success();
3683 }
3684
3685 //===----------------------------------------------------------------------===//
3686 // spv.VectorShuffle
3687 //===----------------------------------------------------------------------===//
3688
verify()3689 LogicalResult spirv::VectorShuffleOp::verify() {
3690 VectorType resultType = getType().cast<VectorType>();
3691
3692 size_t numResultElements = resultType.getNumElements();
3693 if (numResultElements != components().size())
3694 return emitOpError("result type element count (")
3695 << numResultElements
3696 << ") mismatch with the number of component selectors ("
3697 << components().size() << ")";
3698
3699 size_t totalSrcElements =
3700 vector1().getType().cast<VectorType>().getNumElements() +
3701 vector2().getType().cast<VectorType>().getNumElements();
3702
3703 for (const auto &selector : components().getAsValueRange<IntegerAttr>()) {
3704 uint32_t index = selector.getZExtValue();
3705 if (index >= totalSrcElements &&
3706 index != std::numeric_limits<uint32_t>().max())
3707 return emitOpError("component selector ")
3708 << index << " out of range: expected to be in [0, "
3709 << totalSrcElements << ") or 0xffffffff";
3710 }
3711 return success();
3712 }
3713
3714 //===----------------------------------------------------------------------===//
3715 // spv.CooperativeMatrixLoadNV
3716 //===----------------------------------------------------------------------===//
3717
parse(OpAsmParser & parser,OperationState & state)3718 ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
3719 OperationState &state) {
3720 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
3721 Type strideType = parser.getBuilder().getIntegerType(32);
3722 Type columnMajorType = parser.getBuilder().getIntegerType(1);
3723 Type ptrType;
3724 Type elementType;
3725 if (parser.parseOperandList(operandInfo, 3) ||
3726 parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3727 parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
3728 return failure();
3729 }
3730 if (parser.resolveOperands(operandInfo,
3731 {ptrType, strideType, columnMajorType},
3732 parser.getNameLoc(), state.operands)) {
3733 return failure();
3734 }
3735
3736 state.addTypes(elementType);
3737 return success();
3738 }
3739
print(OpAsmPrinter & printer)3740 void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) {
3741 printer << " " << pointer() << ", " << stride() << ", " << columnmajor();
3742 // Print optional memory access attribute.
3743 if (auto memAccess = memory_access())
3744 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3745 printer << " : " << pointer().getType() << " as " << getType();
3746 }
3747
verifyPointerAndCoopMatrixType(Operation * op,Type pointer,Type coopMatrix)3748 static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
3749 Type coopMatrix) {
3750 Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
3751 if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
3752 return op->emitError(
3753 "Pointer must point to a scalar or vector type but provided ")
3754 << pointeeType;
3755 spirv::StorageClass storage =
3756 pointer.cast<spirv::PointerType>().getStorageClass();
3757 if (storage != spirv::StorageClass::Workgroup &&
3758 storage != spirv::StorageClass::StorageBuffer &&
3759 storage != spirv::StorageClass::PhysicalStorageBuffer)
3760 return op->emitError(
3761 "Pointer storage class must be Workgroup, StorageBuffer or "
3762 "PhysicalStorageBufferEXT but provided ")
3763 << stringifyStorageClass(storage);
3764 return success();
3765 }
3766
verify()3767 LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
3768 return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
3769 result().getType());
3770 }
3771
3772 //===----------------------------------------------------------------------===//
3773 // spv.CooperativeMatrixStoreNV
3774 //===----------------------------------------------------------------------===//
3775
parse(OpAsmParser & parser,OperationState & state)3776 ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
3777 OperationState &state) {
3778 SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
3779 Type strideType = parser.getBuilder().getIntegerType(32);
3780 Type columnMajorType = parser.getBuilder().getIntegerType(1);
3781 Type ptrType;
3782 Type elementType;
3783 if (parser.parseOperandList(operandInfo, 4) ||
3784 parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3785 parser.parseType(ptrType) || parser.parseComma() ||
3786 parser.parseType(elementType)) {
3787 return failure();
3788 }
3789 if (parser.resolveOperands(
3790 operandInfo, {ptrType, elementType, strideType, columnMajorType},
3791 parser.getNameLoc(), state.operands)) {
3792 return failure();
3793 }
3794
3795 return success();
3796 }
3797
print(OpAsmPrinter & printer)3798 void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
3799 printer << " " << pointer() << ", " << object() << ", " << stride() << ", "
3800 << columnmajor();
3801 // Print optional memory access attribute.
3802 if (auto memAccess = memory_access())
3803 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3804 printer << " : " << pointer().getType() << ", " << getOperand(1).getType();
3805 }
3806
verify()3807 LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
3808 return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
3809 object().getType());
3810 }
3811
3812 //===----------------------------------------------------------------------===//
3813 // spv.CooperativeMatrixMulAddNV
3814 //===----------------------------------------------------------------------===//
3815
3816 static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op)3817 verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
3818 if (op.c().getType() != op.result().getType())
3819 return op.emitOpError("result and third operand must have the same type");
3820 auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
3821 auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
3822 auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
3823 auto typeR = op.result().getType().cast<spirv::CooperativeMatrixNVType>();
3824 if (typeA.getRows() != typeR.getRows() ||
3825 typeA.getColumns() != typeB.getRows() ||
3826 typeB.getColumns() != typeR.getColumns())
3827 return op.emitOpError("matrix size must match");
3828 if (typeR.getScope() != typeA.getScope() ||
3829 typeR.getScope() != typeB.getScope() ||
3830 typeR.getScope() != typeC.getScope())
3831 return op.emitOpError("matrix scope must match");
3832 if (typeA.getElementType() != typeB.getElementType() ||
3833 typeR.getElementType() != typeC.getElementType())
3834 return op.emitOpError("matrix element type must match");
3835 return success();
3836 }
3837
verify()3838 LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() {
3839 return verifyCoopMatrixMulAdd(*this);
3840 }
3841
3842 //===----------------------------------------------------------------------===//
3843 // spv.MatrixTimesScalar
3844 //===----------------------------------------------------------------------===//
3845
verify()3846 LogicalResult spirv::MatrixTimesScalarOp::verify() {
3847 // We already checked that result and matrix are both of matrix type in the
3848 // auto-generated verify method.
3849
3850 auto inputMatrix = matrix().getType().cast<spirv::MatrixType>();
3851 auto resultMatrix = result().getType().cast<spirv::MatrixType>();
3852
3853 // Check that the scalar type is the same as the matrix element type.
3854 if (scalar().getType() != inputMatrix.getElementType())
3855 return emitError("input matrix components' type and scaling value must "
3856 "have the same type");
3857
3858 // Note that the next three checks could be done using the AllTypesMatch
3859 // trait in the Op definition file but it generates a vague error message.
3860
3861 // Check that the input and result matrices have the same columns' count
3862 if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
3863 return emitError("input and result matrices must have the same "
3864 "number of columns");
3865
3866 // Check that the input and result matrices' have the same rows count
3867 if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
3868 return emitError("input and result matrices' columns must have "
3869 "the same size");
3870
3871 // Check that the input and result matrices' have the same component type
3872 if (inputMatrix.getElementType() != resultMatrix.getElementType())
3873 return emitError("input and result matrices' columns must have "
3874 "the same component type");
3875
3876 return success();
3877 }
3878
3879 //===----------------------------------------------------------------------===//
3880 // spv.CopyMemory
3881 //===----------------------------------------------------------------------===//
3882
print(OpAsmPrinter & printer)3883 void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) {
3884 printer << ' ';
3885
3886 StringRef targetStorageClass = stringifyStorageClass(
3887 target().getType().cast<spirv::PointerType>().getStorageClass());
3888 printer << " \"" << targetStorageClass << "\" " << target() << ", ";
3889
3890 StringRef sourceStorageClass = stringifyStorageClass(
3891 source().getType().cast<spirv::PointerType>().getStorageClass());
3892 printer << " \"" << sourceStorageClass << "\" " << source();
3893
3894 SmallVector<StringRef, 4> elidedAttrs;
3895 printMemoryAccessAttribute(*this, printer, elidedAttrs);
3896 printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
3897 source_memory_access(), source_alignment());
3898
3899 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3900
3901 Type pointeeType =
3902 target().getType().cast<spirv::PointerType>().getPointeeType();
3903 printer << " : " << pointeeType;
3904 }
3905
parse(OpAsmParser & parser,OperationState & state)3906 ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser,
3907 OperationState &state) {
3908 spirv::StorageClass targetStorageClass;
3909 OpAsmParser::UnresolvedOperand targetPtrInfo;
3910
3911 spirv::StorageClass sourceStorageClass;
3912 OpAsmParser::UnresolvedOperand sourcePtrInfo;
3913
3914 Type elementType;
3915
3916 if (parseEnumStrAttr(targetStorageClass, parser) ||
3917 parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
3918 parseEnumStrAttr(sourceStorageClass, parser) ||
3919 parser.parseOperand(sourcePtrInfo) ||
3920 parseMemoryAccessAttributes(parser, state)) {
3921 return failure();
3922 }
3923
3924 if (!parser.parseOptionalComma()) {
3925 // Parse 2nd memory access attributes.
3926 if (parseSourceMemoryAccessAttributes(parser, state)) {
3927 return failure();
3928 }
3929 }
3930
3931 if (parser.parseColon() || parser.parseType(elementType))
3932 return failure();
3933
3934 if (parser.parseOptionalAttrDict(state.attributes))
3935 return failure();
3936
3937 auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
3938 auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
3939
3940 if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) ||
3941 parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) {
3942 return failure();
3943 }
3944
3945 return success();
3946 }
3947
verify()3948 LogicalResult spirv::CopyMemoryOp::verify() {
3949 Type targetType =
3950 target().getType().cast<spirv::PointerType>().getPointeeType();
3951
3952 Type sourceType =
3953 source().getType().cast<spirv::PointerType>().getPointeeType();
3954
3955 if (targetType != sourceType)
3956 return emitOpError("both operands must be pointers to the same type");
3957
3958 if (failed(verifyMemoryAccessAttribute(*this)))
3959 return failure();
3960
3961 // TODO - According to the spec:
3962 //
3963 // If two masks are present, the first applies to Target and cannot include
3964 // MakePointerVisible, and the second applies to Source and cannot include
3965 // MakePointerAvailable.
3966 //
3967 // Add such verification here.
3968
3969 return verifySourceMemoryAccessAttribute(*this);
3970 }
3971
3972 //===----------------------------------------------------------------------===//
3973 // spv.Transpose
3974 //===----------------------------------------------------------------------===//
3975
verify()3976 LogicalResult spirv::TransposeOp::verify() {
3977 auto inputMatrix = matrix().getType().cast<spirv::MatrixType>();
3978 auto resultMatrix = result().getType().cast<spirv::MatrixType>();
3979
3980 // Verify that the input and output matrices have correct shapes.
3981 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
3982 return emitError("input matrix rows count must be equal to "
3983 "output matrix columns count");
3984
3985 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
3986 return emitError("input matrix columns count must be equal to "
3987 "output matrix rows count");
3988
3989 // Verify that the input and output matrices have the same component type
3990 if (inputMatrix.getElementType() != resultMatrix.getElementType())
3991 return emitError("input and output matrices must have the same "
3992 "component type");
3993
3994 return success();
3995 }
3996
3997 //===----------------------------------------------------------------------===//
3998 // spv.MatrixTimesMatrix
3999 //===----------------------------------------------------------------------===//
4000
verify()4001 LogicalResult spirv::MatrixTimesMatrixOp::verify() {
4002 auto leftMatrix = leftmatrix().getType().cast<spirv::MatrixType>();
4003 auto rightMatrix = rightmatrix().getType().cast<spirv::MatrixType>();
4004 auto resultMatrix = result().getType().cast<spirv::MatrixType>();
4005
4006 // left matrix columns' count and right matrix rows' count must be equal
4007 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
4008 return emitError("left matrix columns' count must be equal to "
4009 "the right matrix rows' count");
4010
4011 // right and result matrices columns' count must be the same
4012 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
4013 return emitError(
4014 "right and result matrices must have equal columns' count");
4015
4016 // right and result matrices component type must be the same
4017 if (rightMatrix.getElementType() != resultMatrix.getElementType())
4018 return emitError("right and result matrices' component type must"
4019 " be the same");
4020
4021 // left and result matrices component type must be the same
4022 if (leftMatrix.getElementType() != resultMatrix.getElementType())
4023 return emitError("left and result matrices' component type"
4024 " must be the same");
4025
4026 // left and result matrices rows count must be the same
4027 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
4028 return emitError("left and result matrices must have equal rows' count");
4029
4030 return success();
4031 }
4032
4033 //===----------------------------------------------------------------------===//
4034 // spv.SpecConstantComposite
4035 //===----------------------------------------------------------------------===//
4036
parse(OpAsmParser & parser,OperationState & state)4037 ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
4038 OperationState &state) {
4039
4040 StringAttr compositeName;
4041 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
4042 state.attributes))
4043 return failure();
4044
4045 if (parser.parseLParen())
4046 return failure();
4047
4048 SmallVector<Attribute, 4> constituents;
4049
4050 do {
4051 // The name of the constituent attribute isn't important
4052 const char *attrName = "spec_const";
4053 FlatSymbolRefAttr specConstRef;
4054 NamedAttrList attrs;
4055
4056 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
4057 return failure();
4058
4059 constituents.push_back(specConstRef);
4060 } while (!parser.parseOptionalComma());
4061
4062 if (parser.parseRParen())
4063 return failure();
4064
4065 state.addAttribute(kCompositeSpecConstituentsName,
4066 parser.getBuilder().getArrayAttr(constituents));
4067
4068 Type type;
4069 if (parser.parseColonType(type))
4070 return failure();
4071
4072 state.addAttribute(kTypeAttrName, TypeAttr::get(type));
4073
4074 return success();
4075 }
4076
print(OpAsmPrinter & printer)4077 void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
4078 printer << " ";
4079 printer.printSymbolName(sym_name());
4080 printer << " (";
4081 auto constituents = this->constituents().getValue();
4082
4083 if (!constituents.empty())
4084 llvm::interleaveComma(constituents, printer);
4085
4086 printer << ") : " << type();
4087 }
4088
verify()4089 LogicalResult spirv::SpecConstantCompositeOp::verify() {
4090 auto cType = type().dyn_cast<spirv::CompositeType>();
4091 auto constituents = this->constituents().getValue();
4092
4093 if (!cType)
4094 return emitError("result type must be a composite type, but provided ")
4095 << type();
4096
4097 if (cType.isa<spirv::CooperativeMatrixNVType>())
4098 return emitError("unsupported composite type ") << cType;
4099 if (constituents.size() != cType.getNumElements())
4100 return emitError("has incorrect number of operands: expected ")
4101 << cType.getNumElements() << ", but provided "
4102 << constituents.size();
4103
4104 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
4105 auto constituent = constituents[index].cast<FlatSymbolRefAttr>();
4106
4107 auto constituentSpecConstOp =
4108 dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
4109 (*this)->getParentOp(), constituent.getAttr()));
4110
4111 if (constituentSpecConstOp.default_value().getType() !=
4112 cType.getElementType(index))
4113 return emitError("has incorrect types of operands: expected ")
4114 << cType.getElementType(index) << ", but provided "
4115 << constituentSpecConstOp.default_value().getType();
4116 }
4117
4118 return success();
4119 }
4120
4121 //===----------------------------------------------------------------------===//
4122 // spv.SpecConstantOperation
4123 //===----------------------------------------------------------------------===//
4124
parse(OpAsmParser & parser,OperationState & state)4125 ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
4126 OperationState &state) {
4127 Region *body = state.addRegion();
4128
4129 if (parser.parseKeyword("wraps"))
4130 return failure();
4131
4132 body->push_back(new Block);
4133 Block &block = body->back();
4134 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
4135
4136 if (!wrappedOp)
4137 return failure();
4138
4139 OpBuilder builder(parser.getContext());
4140 builder.setInsertionPointToEnd(&block);
4141 builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
4142 state.location = wrappedOp->getLoc();
4143
4144 state.addTypes(wrappedOp->getResult(0).getType());
4145
4146 if (parser.parseOptionalAttrDict(state.attributes))
4147 return failure();
4148
4149 return success();
4150 }
4151
print(OpAsmPrinter & printer)4152 void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
4153 printer << " wraps ";
4154 printer.printGenericOp(&body().front().front());
4155 }
4156
verifyRegions()4157 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
4158 Block &block = getRegion().getBlocks().front();
4159
4160 if (block.getOperations().size() != 2)
4161 return emitOpError("expected exactly 2 nested ops");
4162
4163 Operation &enclosedOp = block.getOperations().front();
4164
4165 if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
4166 return emitOpError("invalid enclosed op");
4167
4168 for (auto operand : enclosedOp.getOperands())
4169 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
4170 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
4171 return emitOpError(
4172 "invalid operand, must be defined by a constant operation");
4173
4174 return success();
4175 }
4176
4177 //===----------------------------------------------------------------------===//
4178 // spv.GL.FrexpStruct
4179 //===----------------------------------------------------------------------===//
4180
verify()4181 LogicalResult spirv::GLFrexpStructOp::verify() {
4182 spirv::StructType structTy = result().getType().dyn_cast<spirv::StructType>();
4183
4184 if (structTy.getNumElements() != 2)
4185 return emitError("result type must be a struct type with two memebers");
4186
4187 Type significandTy = structTy.getElementType(0);
4188 Type exponentTy = structTy.getElementType(1);
4189 VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
4190 IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();
4191
4192 Type operandTy = operand().getType();
4193 VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
4194 FloatType operandFTy = operandTy.dyn_cast<FloatType>();
4195
4196 if (significandTy != operandTy)
4197 return emitError("member zero of the resulting struct type must be the "
4198 "same type as the operand");
4199
4200 if (exponentVecTy) {
4201 IntegerType componentIntTy =
4202 exponentVecTy.getElementType().dyn_cast<IntegerType>();
4203 if (!componentIntTy || componentIntTy.getWidth() != 32)
4204 return emitError("member one of the resulting struct type must"
4205 "be a scalar or vector of 32 bit integer type");
4206 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
4207 return emitError("member one of the resulting struct type "
4208 "must be a scalar or vector of 32 bit integer type");
4209 }
4210
4211 // Check that the two member types have the same number of components
4212 if (operandVecTy && exponentVecTy &&
4213 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
4214 return success();
4215
4216 if (operandFTy && exponentIntTy)
4217 return success();
4218
4219 return emitError("member one of the resulting struct type must have the same "
4220 "number of components as the operand type");
4221 }
4222
4223 //===----------------------------------------------------------------------===//
4224 // spv.GL.Ldexp
4225 //===----------------------------------------------------------------------===//
4226
verify()4227 LogicalResult spirv::GLLdexpOp::verify() {
4228 Type significandType = x().getType();
4229 Type exponentType = exp().getType();
4230
4231 if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
4232 return emitOpError("operands must both be scalars or vectors");
4233
4234 auto getNumElements = [](Type type) -> unsigned {
4235 if (auto vectorType = type.dyn_cast<VectorType>())
4236 return vectorType.getNumElements();
4237 return 1;
4238 };
4239
4240 if (getNumElements(significandType) != getNumElements(exponentType))
4241 return emitOpError("operands must have the same number of elements");
4242
4243 return success();
4244 }
4245
4246 //===----------------------------------------------------------------------===//
4247 // spv.ImageDrefGather
4248 //===----------------------------------------------------------------------===//
4249
verify()4250 LogicalResult spirv::ImageDrefGatherOp::verify() {
4251 VectorType resultType = result().getType().cast<VectorType>();
4252 auto sampledImageType =
4253 sampledimage().getType().cast<spirv::SampledImageType>();
4254 auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();
4255
4256 if (resultType.getNumElements() != 4)
4257 return emitOpError("result type must be a vector of four components");
4258
4259 Type elementType = resultType.getElementType();
4260 Type sampledElementType = imageType.getElementType();
4261 if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType)
4262 return emitOpError(
4263 "the component type of result must be the same as sampled type of the "
4264 "underlying image type");
4265
4266 spirv::Dim imageDim = imageType.getDim();
4267 spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
4268
4269 if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
4270 imageDim != spirv::Dim::Rect)
4271 return emitOpError(
4272 "the Dim operand of the underlying image type must be 2D, Cube, or "
4273 "Rect");
4274
4275 if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
4276 return emitOpError("the MS operand of the underlying image type must be 0");
4277
4278 spirv::ImageOperandsAttr attr = imageoperandsAttr();
4279 auto operandArguments = operand_arguments();
4280
4281 return verifyImageOperands(*this, attr, operandArguments);
4282 }
4283
4284 //===----------------------------------------------------------------------===//
4285 // spv.ShiftLeftLogicalOp
4286 //===----------------------------------------------------------------------===//
4287
verify()4288 LogicalResult spirv::ShiftLeftLogicalOp::verify() {
4289 return verifyShiftOp(*this);
4290 }
4291
4292 //===----------------------------------------------------------------------===//
4293 // spv.ShiftRightArithmeticOp
4294 //===----------------------------------------------------------------------===//
4295
verify()4296 LogicalResult spirv::ShiftRightArithmeticOp::verify() {
4297 return verifyShiftOp(*this);
4298 }
4299
4300 //===----------------------------------------------------------------------===//
4301 // spv.ShiftRightLogicalOp
4302 //===----------------------------------------------------------------------===//
4303
verify()4304 LogicalResult spirv::ShiftRightLogicalOp::verify() {
4305 return verifyShiftOp(*this);
4306 }
4307
4308 //===----------------------------------------------------------------------===//
4309 // spv.ImageQuerySize
4310 //===----------------------------------------------------------------------===//
4311
verify()4312 LogicalResult spirv::ImageQuerySizeOp::verify() {
4313 spirv::ImageType imageType = image().getType().cast<spirv::ImageType>();
4314 Type resultType = result().getType();
4315
4316 spirv::Dim dim = imageType.getDim();
4317 spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
4318 spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
4319 switch (dim) {
4320 case spirv::Dim::Dim1D:
4321 case spirv::Dim::Dim2D:
4322 case spirv::Dim::Dim3D:
4323 case spirv::Dim::Cube:
4324 if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
4325 samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
4326 samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
4327 return emitError(
4328 "if Dim is 1D, 2D, 3D, or Cube, "
4329 "it must also have either an MS of 1 or a Sampled of 0 or 2");
4330 break;
4331 case spirv::Dim::Buffer:
4332 case spirv::Dim::Rect:
4333 break;
4334 default:
4335 return emitError("the Dim operand of the image type must "
4336 "be 1D, 2D, 3D, Buffer, Cube, or Rect");
4337 }
4338
4339 unsigned componentNumber = 0;
4340 switch (dim) {
4341 case spirv::Dim::Dim1D:
4342 case spirv::Dim::Buffer:
4343 componentNumber = 1;
4344 break;
4345 case spirv::Dim::Dim2D:
4346 case spirv::Dim::Cube:
4347 case spirv::Dim::Rect:
4348 componentNumber = 2;
4349 break;
4350 case spirv::Dim::Dim3D:
4351 componentNumber = 3;
4352 break;
4353 default:
4354 break;
4355 }
4356
4357 if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
4358 componentNumber += 1;
4359
4360 unsigned resultComponentNumber = 1;
4361 if (auto resultVectorType = resultType.dyn_cast<VectorType>())
4362 resultComponentNumber = resultVectorType.getNumElements();
4363
4364 if (componentNumber != resultComponentNumber)
4365 return emitError("expected the result to have ")
4366 << componentNumber << " component(s), but found "
4367 << resultComponentNumber << " component(s)";
4368
4369 return success();
4370 }
4371
parsePtrAccessChainOpImpl(StringRef opName,OpAsmParser & parser,OperationState & state)4372 static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
4373 OpAsmParser &parser,
4374 OperationState &state) {
4375 OpAsmParser::UnresolvedOperand ptrInfo;
4376 SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
4377 Type type;
4378 auto loc = parser.getCurrentLocation();
4379 SmallVector<Type, 4> indicesTypes;
4380
4381 if (parser.parseOperand(ptrInfo) ||
4382 parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
4383 parser.parseColonType(type) ||
4384 parser.resolveOperand(ptrInfo, type, state.operands))
4385 return failure();
4386
4387 // Check that the provided indices list is not empty before parsing their
4388 // type list.
4389 if (indicesInfo.empty())
4390 return emitError(state.location) << opName << " expected element";
4391
4392 if (parser.parseComma() || parser.parseTypeList(indicesTypes))
4393 return failure();
4394
4395 // Check that the indices types list is not empty and that it has a one-to-one
4396 // mapping to the provided indices.
4397 if (indicesTypes.size() != indicesInfo.size())
4398 return emitError(state.location)
4399 << opName
4400 << " indices types' count must be equal to indices info count";
4401
4402 if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
4403 return failure();
4404
4405 auto resultType = getElementPtrType(
4406 type, llvm::makeArrayRef(state.operands).drop_front(2), state.location);
4407 if (!resultType)
4408 return failure();
4409
4410 state.addTypes(resultType);
4411 return success();
4412 }
4413
4414 template <typename Op>
concatElemAndIndices(Op op)4415 static auto concatElemAndIndices(Op op) {
4416 SmallVector<Value> ret(op.indices().size() + 1);
4417 ret[0] = op.element();
4418 llvm::copy(op.indices(), ret.begin() + 1);
4419 return ret;
4420 }
4421
4422 //===----------------------------------------------------------------------===//
4423 // spv.InBoundsPtrAccessChainOp
4424 //===----------------------------------------------------------------------===//
4425
build(OpBuilder & builder,OperationState & state,Value basePtr,Value element,ValueRange indices)4426 void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
4427 OperationState &state,
4428 Value basePtr, Value element,
4429 ValueRange indices) {
4430 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4431 assert(type && "Unable to deduce return type based on basePtr and indices");
4432 build(builder, state, type, basePtr, element, indices);
4433 }
4434
parse(OpAsmParser & parser,OperationState & state)4435 ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
4436 OperationState &state) {
4437 return parsePtrAccessChainOpImpl(
4438 spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, state);
4439 }
4440
print(OpAsmPrinter & printer)4441 void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
4442 printAccessChain(*this, concatElemAndIndices(*this), printer);
4443 }
4444
verify()4445 LogicalResult spirv::InBoundsPtrAccessChainOp::verify() {
4446 return verifyAccessChain(*this, indices());
4447 }
4448
4449 //===----------------------------------------------------------------------===//
4450 // spv.PtrAccessChainOp
4451 //===----------------------------------------------------------------------===//
4452
build(OpBuilder & builder,OperationState & state,Value basePtr,Value element,ValueRange indices)4453 void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
4454 Value basePtr, Value element,
4455 ValueRange indices) {
4456 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4457 assert(type && "Unable to deduce return type based on basePtr and indices");
4458 build(builder, state, type, basePtr, element, indices);
4459 }
4460
parse(OpAsmParser & parser,OperationState & state)4461 ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser,
4462 OperationState &state) {
4463 return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
4464 parser, state);
4465 }
4466
print(OpAsmPrinter & printer)4467 void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) {
4468 printAccessChain(*this, concatElemAndIndices(*this), printer);
4469 }
4470
verify()4471 LogicalResult spirv::PtrAccessChainOp::verify() {
4472 return verifyAccessChain(*this, indices());
4473 }
4474
4475 //===----------------------------------------------------------------------===//
4476 // spv.VectorTimesScalarOp
4477 //===----------------------------------------------------------------------===//
4478
verify()4479 LogicalResult spirv::VectorTimesScalarOp::verify() {
4480 if (vector().getType() != getType())
4481 return emitOpError("vector operand and result type mismatch");
4482 auto scalarType = getType().cast<VectorType>().getElementType();
4483 if (scalar().getType() != scalarType)
4484 return emitOpError("scalar operand and result element type match");
4485 return success();
4486 }
4487
4488 // TableGen'erated operation interfaces for querying versions, extensions, and
4489 // capabilities.
4490 #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
4491
// TableGen'erated operation definitions.
4493 #define GET_OP_CLASSES
4494 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
4495
4496 namespace mlir {
4497 namespace spirv {
4498 // TableGen'erated operation availability interface implementations.
4499 #include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
4500 } // namespace spirv
4501 } // namespace mlir
4502