1 //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the operations in the SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14 
15 #include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18 #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
19 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
20 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
21 #include "mlir/IR/Builders.h"
22 #include "mlir/IR/BuiltinOps.h"
23 #include "mlir/IR/BuiltinTypes.h"
24 #include "mlir/IR/FunctionImplementation.h"
25 #include "mlir/IR/OpDefinition.h"
26 #include "mlir/IR/OpImplementation.h"
27 #include "mlir/IR/TypeUtilities.h"
28 #include "mlir/Interfaces/CallInterfaces.h"
29 #include "llvm/ADT/APFloat.h"
30 #include "llvm/ADT/APInt.h"
31 #include "llvm/ADT/STLExtras.h"
32 #include "llvm/ADT/StringExtras.h"
33 #include "llvm/ADT/bit.h"
34 
35 using namespace mlir;
36 
// Attribute names used by the hand-written parsers, printers and verifiers in
// this file.
// TODO: generate these strings using ODS.

// Memory access operands: target side first, then the source side used by
// ops that take two memory operands.
static constexpr char kMemoryAccessAttrName[] = "memory_access";
static constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
static constexpr char kAlignmentAttrName[] = "alignment";
static constexpr char kSourceAlignmentAttrName[] = "source_alignment";

// Remaining attribute / keyword names.
static constexpr char kBranchWeightAttrName[] = "branch_weights";
static constexpr char kCallee[] = "callee";
static constexpr char kClusterSize[] = "cluster_size";
static constexpr char kControl[] = "control";
static constexpr char kDefaultValueAttrName[] = "default_value";
static constexpr char kExecutionScopeAttrName[] = "execution_scope";
static constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
static constexpr char kFnNameAttrName[] = "fn";
static constexpr char kGroupOperationAttrName[] = "group_operation";
static constexpr char kIndicesAttrName[] = "indices";
static constexpr char kInitializerAttrName[] = "initializer";
static constexpr char kInterfaceAttrName[] = "interface";
static constexpr char kMemoryScopeAttrName[] = "memory_scope";
static constexpr char kSemanticsAttrName[] = "semantics";
static constexpr char kSpecIdAttrName[] = "spec_id";
static constexpr char kTypeAttrName[] = "type";
static constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
static constexpr char kValueAttrName[] = "value";
static constexpr char kValuesAttrName[] = "values";
static constexpr char kCompositeSpecConstituentsName[] = "constituents";
63 
64 //===----------------------------------------------------------------------===//
65 // Common utility functions
66 //===----------------------------------------------------------------------===//
67 
parseOneResultSameOperandTypeOp(OpAsmParser & parser,OperationState & result)68 static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
69                                                    OperationState &result) {
70   SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
71   Type type;
72   // If the operand list is in-between parentheses, then we have a generic form.
73   // (see the fallback in `printOneResultOp`).
74   SMLoc loc = parser.getCurrentLocation();
75   if (!parser.parseOptionalLParen()) {
76     if (parser.parseOperandList(ops) || parser.parseRParen() ||
77         parser.parseOptionalAttrDict(result.attributes) ||
78         parser.parseColon() || parser.parseType(type))
79       return failure();
80     auto fnType = type.dyn_cast<FunctionType>();
81     if (!fnType) {
82       parser.emitError(loc, "expected function type");
83       return failure();
84     }
85     if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
86       return failure();
87     result.addTypes(fnType.getResults());
88     return success();
89   }
90   return failure(parser.parseOperandList(ops) ||
91                  parser.parseOptionalAttrDict(result.attributes) ||
92                  parser.parseColonType(type) ||
93                  parser.resolveOperands(ops, type, result.operands) ||
94                  parser.addTypeToList(type, result.types));
95 }
96 
printOneResultOp(Operation * op,OpAsmPrinter & p)97 static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
98   assert(op->getNumResults() == 1 && "op should have one result");
99 
100   // If not all the operand and result types are the same, just use the
101   // generic assembly form to avoid omitting information in printing.
102   auto resultType = op->getResult(0).getType();
103   if (llvm::any_of(op->getOperandTypes(),
104                    [&](Type type) { return type != resultType; })) {
105     p.printGenericOp(op, /*printOpName=*/false);
106     return;
107   }
108 
109   p << ' ';
110   p.printOperands(op->getOperands());
111   p.printOptionalAttrDict(op->getAttrs());
112   // Now we can output only one type for all operands and the result.
113   p << " : " << resultType;
114 }
115 
116 /// Returns true if the given op is a function-like op or nested in a
117 /// function-like op without a module-like op in the middle.
isNestedInFunctionOpInterface(Operation * op)118 static bool isNestedInFunctionOpInterface(Operation *op) {
119   if (!op)
120     return false;
121   if (op->hasTrait<OpTrait::SymbolTable>())
122     return false;
123   if (isa<FunctionOpInterface>(op))
124     return true;
125   return isNestedInFunctionOpInterface(op->getParentOp());
126 }
127 
128 /// Returns true if the given op is an module-like op that maintains a symbol
129 /// table.
isDirectInModuleLikeOp(Operation * op)130 static bool isDirectInModuleLikeOp(Operation *op) {
131   return op && op->hasTrait<OpTrait::SymbolTable>();
132 }
133 
extractValueFromConstOp(Operation * op,int32_t & value)134 static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
135   auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
136   if (!constOp) {
137     return failure();
138   }
139   auto valueAttr = constOp.value();
140   auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
141   if (!integerValueAttr) {
142     return failure();
143   }
144 
145   if (integerValueAttr.getType().isSignlessInteger())
146     value = integerValueAttr.getInt();
147   else
148     value = integerValueAttr.getSInt();
149 
150   return success();
151 }
152 
153 template <typename Ty>
154 static ArrayAttr
getStrArrayAttrForEnumList(Builder & builder,ArrayRef<Ty> enumValues,function_ref<StringRef (Ty)> stringifyFn)155 getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
156                            function_ref<StringRef(Ty)> stringifyFn) {
157   if (enumValues.empty()) {
158     return nullptr;
159   }
160   SmallVector<StringRef, 1> enumValStrs;
161   enumValStrs.reserve(enumValues.size());
162   for (auto val : enumValues) {
163     enumValStrs.emplace_back(stringifyFn(val));
164   }
165   return builder.getStrArrayAttr(enumValStrs);
166 }
167 
168 /// Parses the next string attribute in `parser` as an enumerant of the given
169 /// `EnumClass`.
170 template <typename EnumClass>
171 static ParseResult
parseEnumStrAttr(EnumClass & value,OpAsmParser & parser,StringRef attrName=spirv::attributeName<EnumClass> ())172 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
173                  StringRef attrName = spirv::attributeName<EnumClass>()) {
174   Attribute attrVal;
175   NamedAttrList attr;
176   auto loc = parser.getCurrentLocation();
177   if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
178                             attrName, attr)) {
179     return failure();
180   }
181   if (!attrVal.isa<StringAttr>()) {
182     return parser.emitError(loc, "expected ")
183            << attrName << " attribute specified as string";
184   }
185   auto attrOptional =
186       spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
187   if (!attrOptional) {
188     return parser.emitError(loc, "invalid ")
189            << attrName << " attribute specification: " << attrVal;
190   }
191   value = *attrOptional;
192   return success();
193 }
194 
195 /// Parses the next string attribute in `parser` as an enumerant of the given
196 /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
197 /// attribute with the enum class's name as attribute name.
198 template <typename EnumClass>
199 static ParseResult
parseEnumStrAttr(EnumClass & value,OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())200 parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
201                  StringRef attrName = spirv::attributeName<EnumClass>()) {
202   if (parseEnumStrAttr(value, parser)) {
203     return failure();
204   }
205   state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
206                                    llvm::bit_cast<int32_t>(value)));
207   return success();
208 }
209 
210 /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
211 /// and inserts the enumerant into `state` as an 32-bit integer attribute with
212 /// the enum class's name as attribute name.
213 template <typename EnumClass>
214 static ParseResult
parseEnumKeywordAttr(EnumClass & value,OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())215 parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
216                      OperationState &state,
217                      StringRef attrName = spirv::attributeName<EnumClass>()) {
218   if (parseEnumKeywordAttr(value, parser)) {
219     return failure();
220   }
221   state.addAttribute(attrName, parser.getBuilder().getI32IntegerAttr(
222                                    llvm::bit_cast<int32_t>(value)));
223   return success();
224 }
225 
226 /// Parses Function, Selection and Loop control attributes. If no control is
227 /// specified, "None" is used as a default.
228 template <typename EnumClass>
229 static ParseResult
parseControlAttribute(OpAsmParser & parser,OperationState & state,StringRef attrName=spirv::attributeName<EnumClass> ())230 parseControlAttribute(OpAsmParser &parser, OperationState &state,
231                       StringRef attrName = spirv::attributeName<EnumClass>()) {
232   if (succeeded(parser.parseOptionalKeyword(kControl))) {
233     EnumClass control;
234     if (parser.parseLParen() || parseEnumKeywordAttr(control, parser, state) ||
235         parser.parseRParen())
236       return failure();
237     return success();
238   }
239   // Set control to "None" otherwise.
240   Builder builder = parser.getBuilder();
241   state.addAttribute(attrName, builder.getI32IntegerAttr(0));
242   return success();
243 }
244 
245 /// Parses optional memory access attributes attached to a memory access
246 /// operand/pointer. Specifically, parses the following syntax:
247 ///     (`[` memory-access `]`)?
248 /// where:
249 ///     memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
250 ///         integer-literal | `"NonTemporal"`
parseMemoryAccessAttributes(OpAsmParser & parser,OperationState & state)251 static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
252                                                OperationState &state) {
253   // Parse an optional list of attributes staring with '['
254   if (parser.parseOptionalLSquare()) {
255     // Nothing to do
256     return success();
257   }
258 
259   spirv::MemoryAccess memoryAccessAttr;
260   if (parseEnumStrAttr(memoryAccessAttr, parser, state,
261                        kMemoryAccessAttrName)) {
262     return failure();
263   }
264 
265   if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
266     // Parse integer attribute for alignment.
267     Attribute alignmentAttr;
268     Type i32Type = parser.getBuilder().getIntegerType(32);
269     if (parser.parseComma() ||
270         parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
271                               state.attributes)) {
272       return failure();
273     }
274   }
275   return parser.parseRSquare();
276 }
277 
278 // TODO Make sure to merge this and the previous function into one template
279 // parameterized by memory access attribute name and alignment. Doing so now
280 // results in VS2017 in producing an internal error (at the call site) that's
281 // not detailed enough to understand what is happening.
parseSourceMemoryAccessAttributes(OpAsmParser & parser,OperationState & state)282 static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
283                                                      OperationState &state) {
284   // Parse an optional list of attributes staring with '['
285   if (parser.parseOptionalLSquare()) {
286     // Nothing to do
287     return success();
288   }
289 
290   spirv::MemoryAccess memoryAccessAttr;
291   if (parseEnumStrAttr(memoryAccessAttr, parser, state,
292                        kSourceMemoryAccessAttrName)) {
293     return failure();
294   }
295 
296   if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
297     // Parse integer attribute for alignment.
298     Attribute alignmentAttr;
299     Type i32Type = parser.getBuilder().getIntegerType(32);
300     if (parser.parseComma() ||
301         parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
302                               state.attributes)) {
303       return failure();
304     }
305   }
306   return parser.parseRSquare();
307 }
308 
309 template <typename MemoryOpTy>
printMemoryAccessAttribute(MemoryOpTy memoryOp,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs,Optional<spirv::MemoryAccess> memoryAccessAtrrValue=None,Optional<uint32_t> alignmentAttrValue=None)310 static void printMemoryAccessAttribute(
311     MemoryOpTy memoryOp, OpAsmPrinter &printer,
312     SmallVectorImpl<StringRef> &elidedAttrs,
313     Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
314     Optional<uint32_t> alignmentAttrValue = None) {
315   // Print optional memory access attribute.
316   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
317                                               : memoryOp.memory_access())) {
318     elidedAttrs.push_back(kMemoryAccessAttrName);
319 
320     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
321 
322     if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
323       // Print integer alignment attribute.
324       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
325                                                : memoryOp.alignment())) {
326         elidedAttrs.push_back(kAlignmentAttrName);
327         printer << ", " << alignment;
328       }
329     }
330     printer << "]";
331   }
332   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
333 }
334 
335 // TODO Make sure to merge this and the previous function into one template
336 // parameterized by memory access attribute name and alignment. Doing so now
337 // results in VS2017 in producing an internal error (at the call site) that's
338 // not detailed enough to understand what is happening.
339 template <typename MemoryOpTy>
printSourceMemoryAccessAttribute(MemoryOpTy memoryOp,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs,Optional<spirv::MemoryAccess> memoryAccessAtrrValue=None,Optional<uint32_t> alignmentAttrValue=None)340 static void printSourceMemoryAccessAttribute(
341     MemoryOpTy memoryOp, OpAsmPrinter &printer,
342     SmallVectorImpl<StringRef> &elidedAttrs,
343     Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
344     Optional<uint32_t> alignmentAttrValue = None) {
345 
346   printer << ", ";
347 
348   // Print optional memory access attribute.
349   if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
350                                               : memoryOp.memory_access())) {
351     elidedAttrs.push_back(kSourceMemoryAccessAttrName);
352 
353     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
354 
355     if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
356       // Print integer alignment attribute.
357       if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
358                                                : memoryOp.alignment())) {
359         elidedAttrs.push_back(kSourceAlignmentAttrName);
360         printer << ", " << alignment;
361       }
362     }
363     printer << "]";
364   }
365   elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
366 }
367 
parseImageOperands(OpAsmParser & parser,spirv::ImageOperandsAttr & attr)368 static ParseResult parseImageOperands(OpAsmParser &parser,
369                                       spirv::ImageOperandsAttr &attr) {
370   // Expect image operands
371   if (parser.parseOptionalLSquare())
372     return success();
373 
374   spirv::ImageOperands imageOperands;
375   if (parseEnumStrAttr(imageOperands, parser))
376     return failure();
377 
378   attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);
379 
380   return parser.parseRSquare();
381 }
382 
printImageOperands(OpAsmPrinter & printer,Operation * imageOp,spirv::ImageOperandsAttr attr)383 static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
384                                spirv::ImageOperandsAttr attr) {
385   if (attr) {
386     auto strImageOperands = stringifyImageOperands(attr.getValue());
387     printer << "[\"" << strImageOperands << "\"]";
388   }
389 }
390 
391 template <typename Op>
verifyImageOperands(Op imageOp,spirv::ImageOperandsAttr attr,Operation::operand_range operands)392 static LogicalResult verifyImageOperands(Op imageOp,
393                                          spirv::ImageOperandsAttr attr,
394                                          Operation::operand_range operands) {
395   if (!attr) {
396     if (operands.empty())
397       return success();
398 
399     return imageOp.emitError("the Image Operands should encode what operands "
400                              "follow, as per Image Operands");
401   }
402 
403   // TODO: Add the validation rules for the following Image Operands.
404   spirv::ImageOperands noSupportOperands =
405       spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
406       spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
407       spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
408       spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
409       spirv::ImageOperands::MakeTexelAvailable |
410       spirv::ImageOperands::MakeTexelVisible |
411       spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
412 
413   if (spirv::bitEnumContains(attr.getValue(), noSupportOperands))
414     llvm_unreachable("unimplemented operands of Image Operands");
415 
416   return success();
417 }
418 
verifyCastOp(Operation * op,bool requireSameBitWidth=true,bool skipBitWidthCheck=false)419 static LogicalResult verifyCastOp(Operation *op,
420                                   bool requireSameBitWidth = true,
421                                   bool skipBitWidthCheck = false) {
422   // Some CastOps have no limit on bit widths for result and operand type.
423   if (skipBitWidthCheck)
424     return success();
425 
426   Type operandType = op->getOperand(0).getType();
427   Type resultType = op->getResult(0).getType();
428 
429   // ODS checks that result type and operand type have the same shape.
430   if (auto vectorType = operandType.dyn_cast<VectorType>()) {
431     operandType = vectorType.getElementType();
432     resultType = resultType.cast<VectorType>().getElementType();
433   }
434 
435   if (auto coopMatrixType =
436           operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
437     operandType = coopMatrixType.getElementType();
438     resultType =
439         resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
440   }
441 
442   auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
443   auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
444   auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
445 
446   if (requireSameBitWidth) {
447     if (!isSameBitWidth) {
448       return op->emitOpError(
449                  "expected the same bit widths for operand type and result "
450                  "type, but provided ")
451              << operandType << " and " << resultType;
452     }
453     return success();
454   }
455 
456   if (isSameBitWidth) {
457     return op->emitOpError(
458                "expected the different bit widths for operand type and result "
459                "type, but provided ")
460            << operandType << " and " << resultType;
461   }
462   return success();
463 }
464 
465 template <typename MemoryOpTy>
verifyMemoryAccessAttribute(MemoryOpTy memoryOp)466 static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
467   // ODS checks for attributes values. Just need to verify that if the
468   // memory-access attribute is Aligned, then the alignment attribute must be
469   // present.
470   auto *op = memoryOp.getOperation();
471   auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
472   if (!memAccessAttr) {
473     // Alignment attribute shouldn't be present if memory access attribute is
474     // not present.
475     if (op->getAttr(kAlignmentAttrName)) {
476       return memoryOp.emitOpError(
477           "invalid alignment specification without aligned memory access "
478           "specification");
479     }
480     return success();
481   }
482 
483   auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
484   auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
485 
486   if (!memAccess) {
487     return memoryOp.emitOpError("invalid memory access specifier: ")
488            << memAccessVal;
489   }
490 
491   if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
492     if (!op->getAttr(kAlignmentAttrName)) {
493       return memoryOp.emitOpError("missing alignment value");
494     }
495   } else {
496     if (op->getAttr(kAlignmentAttrName)) {
497       return memoryOp.emitOpError(
498           "invalid alignment specification with non-aligned memory access "
499           "specification");
500     }
501   }
502   return success();
503 }
504 
505 // TODO Make sure to merge this and the previous function into one template
506 // parameterized by memory access attribute name and alignment. Doing so now
507 // results in VS2017 in producing an internal error (at the call site) that's
508 // not detailed enough to understand what is happening.
509 template <typename MemoryOpTy>
verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp)510 static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
511   // ODS checks for attributes values. Just need to verify that if the
512   // memory-access attribute is Aligned, then the alignment attribute must be
513   // present.
514   auto *op = memoryOp.getOperation();
515   auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
516   if (!memAccessAttr) {
517     // Alignment attribute shouldn't be present if memory access attribute is
518     // not present.
519     if (op->getAttr(kSourceAlignmentAttrName)) {
520       return memoryOp.emitOpError(
521           "invalid alignment specification without aligned memory access "
522           "specification");
523     }
524     return success();
525   }
526 
527   auto memAccessVal = memAccessAttr.template cast<IntegerAttr>();
528   auto memAccess = spirv::symbolizeMemoryAccess(memAccessVal.getInt());
529 
530   if (!memAccess) {
531     return memoryOp.emitOpError("invalid memory access specifier: ")
532            << memAccessVal;
533   }
534 
535   if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
536     if (!op->getAttr(kSourceAlignmentAttrName)) {
537       return memoryOp.emitOpError("missing alignment value");
538     }
539   } else {
540     if (op->getAttr(kSourceAlignmentAttrName)) {
541       return memoryOp.emitOpError(
542           "invalid alignment specification with non-aligned memory access "
543           "specification");
544     }
545   }
546   return success();
547 }
548 
549 static LogicalResult
verifyMemorySemantics(Operation * op,spirv::MemorySemantics memorySemantics)550 verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) {
551   // According to the SPIR-V specification:
552   // "Despite being a mask and allowing multiple bits to be combined, it is
553   // invalid for more than one of these four bits to be set: Acquire, Release,
554   // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
555   // Release semantics is done by setting the AcquireRelease bit, not by setting
556   // two bits."
557   auto atMostOneInSet = spirv::MemorySemantics::Acquire |
558                         spirv::MemorySemantics::Release |
559                         spirv::MemorySemantics::AcquireRelease |
560                         spirv::MemorySemantics::SequentiallyConsistent;
561 
562   auto bitCount = llvm::countPopulation(
563       static_cast<uint32_t>(memorySemantics & atMostOneInSet));
564   if (bitCount > 1) {
565     return op->emitError(
566         "expected at most one of these four memory constraints "
567         "to be set: `Acquire`, `Release`,"
568         "`AcquireRelease` or `SequentiallyConsistent`");
569   }
570   return success();
571 }
572 
573 template <typename LoadStoreOpTy>
verifyLoadStorePtrAndValTypes(LoadStoreOpTy op,Value ptr,Value val)574 static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
575                                                    Value val) {
576   // ODS already checks ptr is spirv::PointerType. Just check that the pointee
577   // type of the pointer and the type of the value are the same
578   //
579   // TODO: Check that the value type satisfies restrictions of
580   // SPIR-V OpLoad/OpStore operations
581   if (val.getType() !=
582       ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
583     return op.emitOpError("mismatch in result type and pointer type");
584   }
585   return success();
586 }
587 
588 template <typename BlockReadWriteOpTy>
verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,Value ptr,Value val)589 static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
590                                                         Value ptr, Value val) {
591   auto valType = val.getType();
592   if (auto valVecTy = valType.dyn_cast<VectorType>())
593     valType = valVecTy.getElementType();
594 
595   if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
596     return op.emitOpError("mismatch in result type and pointer type");
597   }
598   return success();
599 }
600 
parseVariableDecorations(OpAsmParser & parser,OperationState & state)601 static ParseResult parseVariableDecorations(OpAsmParser &parser,
602                                             OperationState &state) {
603   auto builtInName = llvm::convertToSnakeFromCamelCase(
604       stringifyDecoration(spirv::Decoration::BuiltIn));
605   if (succeeded(parser.parseOptionalKeyword("bind"))) {
606     Attribute set, binding;
607     // Parse optional descriptor binding
608     auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
609         stringifyDecoration(spirv::Decoration::DescriptorSet));
610     auto bindingName = llvm::convertToSnakeFromCamelCase(
611         stringifyDecoration(spirv::Decoration::Binding));
612     Type i32Type = parser.getBuilder().getIntegerType(32);
613     if (parser.parseLParen() ||
614         parser.parseAttribute(set, i32Type, descriptorSetName,
615                               state.attributes) ||
616         parser.parseComma() ||
617         parser.parseAttribute(binding, i32Type, bindingName,
618                               state.attributes) ||
619         parser.parseRParen()) {
620       return failure();
621     }
622   } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
623     StringAttr builtIn;
624     if (parser.parseLParen() ||
625         parser.parseAttribute(builtIn, builtInName, state.attributes) ||
626         parser.parseRParen()) {
627       return failure();
628     }
629   }
630 
631   // Parse other attributes
632   if (parser.parseOptionalAttrDict(state.attributes))
633     return failure();
634 
635   return success();
636 }
637 
printVariableDecorations(Operation * op,OpAsmPrinter & printer,SmallVectorImpl<StringRef> & elidedAttrs)638 static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
639                                      SmallVectorImpl<StringRef> &elidedAttrs) {
640   // Print optional descriptor binding
641   auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
642       stringifyDecoration(spirv::Decoration::DescriptorSet));
643   auto bindingName = llvm::convertToSnakeFromCamelCase(
644       stringifyDecoration(spirv::Decoration::Binding));
645   auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
646   auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
647   if (descriptorSet && binding) {
648     elidedAttrs.push_back(descriptorSetName);
649     elidedAttrs.push_back(bindingName);
650     printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
651             << ")";
652   }
653 
654   // Print BuiltIn attribute if present
655   auto builtInName = llvm::convertToSnakeFromCamelCase(
656       stringifyDecoration(spirv::Decoration::BuiltIn));
657   if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
658     printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
659     elidedAttrs.push_back(builtInName);
660   }
661 
662   printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
663 }
664 
665 // Get bit width of types.
getBitWidth(Type type)666 static unsigned getBitWidth(Type type) {
667   if (type.isa<spirv::PointerType>()) {
668     // Just return 64 bits for pointer types for now.
669     // TODO: Make sure not caller relies on the actual pointer width value.
670     return 64;
671   }
672 
673   if (type.isIntOrFloat())
674     return type.getIntOrFloatBitWidth();
675 
676   if (auto vectorType = type.dyn_cast<VectorType>()) {
677     assert(vectorType.getElementType().isIntOrFloat());
678     return vectorType.getNumElements() *
679            vectorType.getElementType().getIntOrFloatBitWidth();
680   }
681   llvm_unreachable("unhandled bit width computation for type");
682 }
683 
684 /// Walks the given type hierarchy with the given indices, potentially down
685 /// to component granularity, to select an element type. Returns null type and
686 /// emits errors with the given loc on failure.
687 static Type
getElementType(Type type,ArrayRef<int32_t> indices,function_ref<InFlightDiagnostic (StringRef)> emitErrorFn)688 getElementType(Type type, ArrayRef<int32_t> indices,
689                function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
690   if (indices.empty()) {
691     emitErrorFn("expected at least one index for spv.CompositeExtract");
692     return nullptr;
693   }
694 
695   for (auto index : indices) {
696     if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
697       if (cType.hasCompileTimeKnownNumElements() &&
698           (index < 0 ||
699            static_cast<uint64_t>(index) >= cType.getNumElements())) {
700         emitErrorFn("index ") << index << " out of bounds for " << type;
701         return nullptr;
702       }
703       type = cType.getElementType(index);
704     } else {
705       emitErrorFn("cannot extract from non-composite type ")
706           << type << " with index " << index;
707       return nullptr;
708     }
709   }
710   return type;
711 }
712 
713 static Type
getElementType(Type type,Attribute indices,function_ref<InFlightDiagnostic (StringRef)> emitErrorFn)714 getElementType(Type type, Attribute indices,
715                function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
716   auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
717   if (!indicesArrayAttr) {
718     emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
719     return nullptr;
720   }
721   if (indicesArrayAttr.empty()) {
722     emitErrorFn("expected at least one index for spv.CompositeExtract");
723     return nullptr;
724   }
725 
726   SmallVector<int32_t, 2> indexVals;
727   for (auto indexAttr : indicesArrayAttr) {
728     auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
729     if (!indexIntAttr) {
730       emitErrorFn("expected an 32-bit integer for index, but found '")
731           << indexAttr << "'";
732       return nullptr;
733     }
734     indexVals.push_back(indexIntAttr.getInt());
735   }
736   return getElementType(type, indexVals, emitErrorFn);
737 }
738 
getElementType(Type type,Attribute indices,Location loc)739 static Type getElementType(Type type, Attribute indices, Location loc) {
740   auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
741     return ::mlir::emitError(loc, err);
742   };
743   return getElementType(type, indices, errorFn);
744 }
745 
getElementType(Type type,Attribute indices,OpAsmParser & parser,SMLoc loc)746 static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
747                            SMLoc loc) {
748   auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
749     return parser.emitError(loc, err);
750   };
751   return getElementType(type, indices, errorFn);
752 }
753 
754 /// Returns true if the given `block` only contains one `spv.mlir.merge` op.
isMergeBlock(Block & block)755 static inline bool isMergeBlock(Block &block) {
756   return !block.empty() && std::next(block.begin()) == block.end() &&
757          isa<spirv::MergeOp>(block.front());
758 }
759 
760 //===----------------------------------------------------------------------===//
761 // Common parsers and printers
762 //===----------------------------------------------------------------------===//
763 
764 // Parses an atomic update op. If the update op does not take a value (like
765 // AtomicIIncrement) `hasValue` must be false.
parseAtomicUpdateOp(OpAsmParser & parser,OperationState & state,bool hasValue)766 static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
767                                        OperationState &state, bool hasValue) {
768   spirv::Scope scope;
769   spirv::MemorySemantics memoryScope;
770   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
771   OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
772   Type type;
773   SMLoc loc;
774   if (parseEnumStrAttr(scope, parser, state, kMemoryScopeAttrName) ||
775       parseEnumStrAttr(memoryScope, parser, state, kSemanticsAttrName) ||
776       parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
777       parser.getCurrentLocation(&loc) || parser.parseColonType(type))
778     return failure();
779 
780   auto ptrType = type.dyn_cast<spirv::PointerType>();
781   if (!ptrType)
782     return parser.emitError(loc, "expected pointer type");
783 
784   SmallVector<Type, 2> operandTypes;
785   operandTypes.push_back(ptrType);
786   if (hasValue)
787     operandTypes.push_back(ptrType.getPointeeType());
788   if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
789                              state.operands))
790     return failure();
791   return parser.addTypeToList(ptrType.getPointeeType(), state.types);
792 }
793 
794 // Prints an atomic update op.
printAtomicUpdateOp(Operation * op,OpAsmPrinter & printer)795 static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
796   printer << " \"";
797   auto scopeAttr = op->getAttrOfType<IntegerAttr>(kMemoryScopeAttrName);
798   printer << spirv::stringifyScope(
799                  static_cast<spirv::Scope>(scopeAttr.getInt()))
800           << "\" \"";
801   auto memorySemanticsAttr = op->getAttrOfType<IntegerAttr>(kSemanticsAttrName);
802   printer << spirv::stringifyMemorySemantics(
803                  static_cast<spirv::MemorySemantics>(
804                      memorySemanticsAttr.getInt()))
805           << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
806 }
807 
808 template <typename T>
809 static StringRef stringifyTypeName();
810 
811 template <>
stringifyTypeName()812 StringRef stringifyTypeName<IntegerType>() {
813   return "integer";
814 }
815 
816 template <>
stringifyTypeName()817 StringRef stringifyTypeName<FloatType>() {
818   return "float";
819 }
820 
821 // Verifies an atomic update op.
822 template <typename ExpectedElementType>
verifyAtomicUpdateOp(Operation * op)823 static LogicalResult verifyAtomicUpdateOp(Operation *op) {
824   auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
825   auto elementType = ptrType.getPointeeType();
826   if (!elementType.isa<ExpectedElementType>())
827     return op->emitOpError() << "pointer operand must point to an "
828                              << stringifyTypeName<ExpectedElementType>()
829                              << " value, found " << elementType;
830 
831   if (op->getNumOperands() > 1) {
832     auto valueType = op->getOperand(1).getType();
833     if (valueType != elementType)
834       return op->emitOpError("expected value to have the same type as the "
835                              "pointer operand's pointee type ")
836              << elementType << ", but found " << valueType;
837   }
838   auto memorySemantics = static_cast<spirv::MemorySemantics>(
839       op->getAttrOfType<IntegerAttr>(kSemanticsAttrName).getInt());
840   if (failed(verifyMemorySemantics(op, memorySemantics))) {
841     return failure();
842   }
843   return success();
844 }
845 
parseGroupNonUniformArithmeticOp(OpAsmParser & parser,OperationState & state)846 static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
847                                                     OperationState &state) {
848   spirv::Scope executionScope;
849   spirv::GroupOperation groupOperation;
850   OpAsmParser::UnresolvedOperand valueInfo;
851   if (parseEnumStrAttr(executionScope, parser, state,
852                        kExecutionScopeAttrName) ||
853       parseEnumStrAttr(groupOperation, parser, state,
854                        kGroupOperationAttrName) ||
855       parser.parseOperand(valueInfo))
856     return failure();
857 
858   Optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
859   if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
860     clusterSizeInfo = OpAsmParser::UnresolvedOperand();
861     if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
862         parser.parseRParen())
863       return failure();
864   }
865 
866   Type resultType;
867   if (parser.parseColonType(resultType))
868     return failure();
869 
870   if (parser.resolveOperand(valueInfo, resultType, state.operands))
871     return failure();
872 
873   if (clusterSizeInfo) {
874     Type i32Type = parser.getBuilder().getIntegerType(32);
875     if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
876       return failure();
877   }
878 
879   return parser.addTypeToList(resultType, state.types);
880 }
881 
printGroupNonUniformArithmeticOp(Operation * groupOp,OpAsmPrinter & printer)882 static void printGroupNonUniformArithmeticOp(Operation *groupOp,
883                                              OpAsmPrinter &printer) {
884   printer << " \""
885           << stringifyScope(static_cast<spirv::Scope>(
886                  groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName)
887                      .getInt()))
888           << "\" \""
889           << stringifyGroupOperation(static_cast<spirv::GroupOperation>(
890                  groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName)
891                      .getInt()))
892           << "\" " << groupOp->getOperand(0);
893 
894   if (groupOp->getNumOperands() > 1)
895     printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
896   printer << " : " << groupOp->getResult(0).getType();
897 }
898 
verifyGroupNonUniformArithmeticOp(Operation * groupOp)899 static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
900   spirv::Scope scope = static_cast<spirv::Scope>(
901       groupOp->getAttrOfType<IntegerAttr>(kExecutionScopeAttrName).getInt());
902   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
903     return groupOp->emitOpError(
904         "execution scope must be 'Workgroup' or 'Subgroup'");
905 
906   spirv::GroupOperation operation = static_cast<spirv::GroupOperation>(
907       groupOp->getAttrOfType<IntegerAttr>(kGroupOperationAttrName).getInt());
908   if (operation == spirv::GroupOperation::ClusteredReduce &&
909       groupOp->getNumOperands() == 1)
910     return groupOp->emitOpError("cluster size operand must be provided for "
911                                 "'ClusteredReduce' group operation");
912   if (groupOp->getNumOperands() > 1) {
913     Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
914     int32_t clusterSize = 0;
915 
916     // TODO: support specialization constant here.
917     if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
918       return groupOp->emitOpError(
919           "cluster size operand must come from a constant op");
920 
921     if (!llvm::isPowerOf2_32(clusterSize))
922       return groupOp->emitOpError(
923           "cluster size operand must be a power of two");
924   }
925   return success();
926 }
927 
928 /// Result of a logical op must be a scalar or vector of boolean type.
getUnaryOpResultType(Type operandType)929 static Type getUnaryOpResultType(Type operandType) {
930   Builder builder(operandType.getContext());
931   Type resultType = builder.getIntegerType(1);
932   if (auto vecType = operandType.dyn_cast<VectorType>())
933     return VectorType::get(vecType.getNumElements(), resultType);
934   return resultType;
935 }
936 
verifyShiftOp(Operation * op)937 static LogicalResult verifyShiftOp(Operation *op) {
938   if (op->getOperand(0).getType() != op->getResult(0).getType()) {
939     return op->emitError("expected the same type for the first operand and "
940                          "result, but provided ")
941            << op->getOperand(0).getType() << " and "
942            << op->getResult(0).getType();
943   }
944   return success();
945 }
946 
buildLogicalBinaryOp(OpBuilder & builder,OperationState & state,Value lhs,Value rhs)947 static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
948                                  Value lhs, Value rhs) {
949   assert(lhs.getType() == rhs.getType());
950 
951   Type boolType = builder.getI1Type();
952   if (auto vecType = lhs.getType().dyn_cast<VectorType>())
953     boolType = VectorType::get(vecType.getShape(), boolType);
954   state.addTypes(boolType);
955 
956   state.addOperands({lhs, rhs});
957 }
958 
buildLogicalUnaryOp(OpBuilder & builder,OperationState & state,Value value)959 static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
960                                 Value value) {
961   Type boolType = builder.getI1Type();
962   if (auto vecType = value.getType().dyn_cast<VectorType>())
963     boolType = VectorType::get(vecType.getShape(), boolType);
964   state.addTypes(boolType);
965 
966   state.addOperands(value);
967 }
968 
969 //===----------------------------------------------------------------------===//
970 // spv.AccessChainOp
971 //===----------------------------------------------------------------------===//
972 
getElementPtrType(Type type,ValueRange indices,Location baseLoc)973 static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
974   auto ptrType = type.dyn_cast<spirv::PointerType>();
975   if (!ptrType) {
976     emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
977                        "to composite type, but provided ")
978         << type;
979     return nullptr;
980   }
981 
982   auto resultType = ptrType.getPointeeType();
983   auto resultStorageClass = ptrType.getStorageClass();
984   int32_t index = 0;
985 
986   for (auto indexSSA : indices) {
987     auto cType = resultType.dyn_cast<spirv::CompositeType>();
988     if (!cType) {
989       emitError(baseLoc,
990                 "'spv.AccessChain' op cannot extract from non-composite type ")
991           << resultType << " with index " << index;
992       return nullptr;
993     }
994     index = 0;
995     if (resultType.isa<spirv::StructType>()) {
996       Operation *op = indexSSA.getDefiningOp();
997       if (!op) {
998         emitError(baseLoc, "'spv.AccessChain' op index must be an "
999                            "integer spv.Constant to access "
1000                            "element of spv.struct");
1001         return nullptr;
1002       }
1003 
1004       // TODO: this should be relaxed to allow
1005       // integer literals of other bitwidths.
1006       if (failed(extractValueFromConstOp(op, index))) {
1007         emitError(baseLoc,
1008                   "'spv.AccessChain' index must be an integer spv.Constant to "
1009                   "access element of spv.struct, but provided ")
1010             << op->getName();
1011         return nullptr;
1012       }
1013       if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
1014         emitError(baseLoc, "'spv.AccessChain' op index ")
1015             << index << " out of bounds for " << resultType;
1016         return nullptr;
1017       }
1018     }
1019     resultType = cType.getElementType(index);
1020   }
1021   return spirv::PointerType::get(resultType, resultStorageClass);
1022 }
1023 
build(OpBuilder & builder,OperationState & state,Value basePtr,ValueRange indices)1024 void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
1025                                  Value basePtr, ValueRange indices) {
1026   auto type = getElementPtrType(basePtr.getType(), indices, state.location);
1027   assert(type && "Unable to deduce return type based on basePtr and indices");
1028   build(builder, state, type, basePtr, indices);
1029 }
1030 
parse(OpAsmParser & parser,OperationState & state)1031 ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser,
1032                                         OperationState &state) {
1033   OpAsmParser::UnresolvedOperand ptrInfo;
1034   SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
1035   Type type;
1036   auto loc = parser.getCurrentLocation();
1037   SmallVector<Type, 4> indicesTypes;
1038 
1039   if (parser.parseOperand(ptrInfo) ||
1040       parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
1041       parser.parseColonType(type) ||
1042       parser.resolveOperand(ptrInfo, type, state.operands)) {
1043     return failure();
1044   }
1045 
1046   // Check that the provided indices list is not empty before parsing their
1047   // type list.
1048   if (indicesInfo.empty()) {
1049     return mlir::emitError(state.location, "'spv.AccessChain' op expected at "
1050                                            "least one index ");
1051   }
1052 
1053   if (parser.parseComma() || parser.parseTypeList(indicesTypes))
1054     return failure();
1055 
1056   // Check that the indices types list is not empty and that it has a one-to-one
1057   // mapping to the provided indices.
1058   if (indicesTypes.size() != indicesInfo.size()) {
1059     return mlir::emitError(state.location,
1060                            "'spv.AccessChain' op indices types' count must be "
1061                            "equal to indices info count");
1062   }
1063 
1064   if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
1065     return failure();
1066 
1067   auto resultType = getElementPtrType(
1068       type, llvm::makeArrayRef(state.operands).drop_front(), state.location);
1069   if (!resultType) {
1070     return failure();
1071   }
1072 
1073   state.addTypes(resultType);
1074   return success();
1075 }
1076 
1077 template <typename Op>
printAccessChain(Op op,ValueRange indices,OpAsmPrinter & printer)1078 static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
1079   printer << ' ' << op.base_ptr() << '[' << indices
1080           << "] : " << op.base_ptr().getType() << ", " << indices.getTypes();
1081 }
1082 
print(OpAsmPrinter & printer)1083 void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
1084   printAccessChain(*this, indices(), printer);
1085 }
1086 
1087 template <typename Op>
verifyAccessChain(Op accessChainOp,ValueRange indices)1088 static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
1089   auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
1090                                       indices, accessChainOp.getLoc());
1091   if (!resultType)
1092     return failure();
1093 
1094   auto providedResultType =
1095       accessChainOp.getType().template dyn_cast<spirv::PointerType>();
1096   if (!providedResultType)
1097     return accessChainOp.emitOpError(
1098                "result type must be a pointer, but provided")
1099            << providedResultType;
1100 
1101   if (resultType != providedResultType)
1102     return accessChainOp.emitOpError("invalid result type: expected ")
1103            << resultType << ", but provided " << providedResultType;
1104 
1105   return success();
1106 }
1107 
verify()1108 LogicalResult spirv::AccessChainOp::verify() {
1109   return verifyAccessChain(*this, indices());
1110 }
1111 
1112 //===----------------------------------------------------------------------===//
1113 // spv.mlir.addressof
1114 //===----------------------------------------------------------------------===//
1115 
build(OpBuilder & builder,OperationState & state,spirv::GlobalVariableOp var)1116 void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
1117                                spirv::GlobalVariableOp var) {
1118   build(builder, state, var.type(), SymbolRefAttr::get(var));
1119 }
1120 
verify()1121 LogicalResult spirv::AddressOfOp::verify() {
1122   auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1123       SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
1124                                            variableAttr()));
1125   if (!varOp) {
1126     return emitOpError("expected spv.GlobalVariable symbol");
1127   }
1128   if (pointer().getType() != varOp.type()) {
1129     return emitOpError(
1130         "result type mismatch with the referenced global variable's type");
1131   }
1132   return success();
1133 }
1134 
1135 template <typename T>
printAtomicCompareExchangeImpl(T atomOp,OpAsmPrinter & printer)1136 static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
1137   printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
1138           << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
1139           << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
1140           << atomOp.getOperands() << " : " << atomOp.pointer().getType();
1141 }
1142 
parseAtomicCompareExchangeImpl(OpAsmParser & parser,OperationState & state)1143 static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
1144                                                   OperationState &state) {
1145   spirv::Scope memoryScope;
1146   spirv::MemorySemantics equalSemantics, unequalSemantics;
1147   SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
1148   Type type;
1149   if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
1150       parseEnumStrAttr(equalSemantics, parser, state,
1151                        kEqualSemanticsAttrName) ||
1152       parseEnumStrAttr(unequalSemantics, parser, state,
1153                        kUnequalSemanticsAttrName) ||
1154       parser.parseOperandList(operandInfo, 3))
1155     return failure();
1156 
1157   auto loc = parser.getCurrentLocation();
1158   if (parser.parseColonType(type))
1159     return failure();
1160 
1161   auto ptrType = type.dyn_cast<spirv::PointerType>();
1162   if (!ptrType)
1163     return parser.emitError(loc, "expected pointer type");
1164 
1165   if (parser.resolveOperands(
1166           operandInfo,
1167           {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1168           parser.getNameLoc(), state.operands))
1169     return failure();
1170 
1171   return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1172 }
1173 
1174 template <typename T>
verifyAtomicCompareExchangeImpl(T atomOp)1175 static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
1176   // According to the spec:
1177   // "The type of Value must be the same as Result Type. The type of the value
1178   // pointed to by Pointer must be the same as Result Type. This type must also
1179   // match the type of Comparator."
1180   if (atomOp.getType() != atomOp.value().getType())
1181     return atomOp.emitOpError("value operand must have the same type as the op "
1182                               "result, but found ")
1183            << atomOp.value().getType() << " vs " << atomOp.getType();
1184 
1185   if (atomOp.getType() != atomOp.comparator().getType())
1186     return atomOp.emitOpError(
1187                "comparator operand must have the same type as the op "
1188                "result, but found ")
1189            << atomOp.comparator().getType() << " vs " << atomOp.getType();
1190 
1191   Type pointeeType = atomOp.pointer()
1192                          .getType()
1193                          .template cast<spirv::PointerType>()
1194                          .getPointeeType();
1195   if (atomOp.getType() != pointeeType)
1196     return atomOp.emitOpError(
1197                "pointer operand's pointee type must have the same "
1198                "as the op result type, but found ")
1199            << pointeeType << " vs " << atomOp.getType();
1200 
1201   // TODO: Unequal cannot be set to Release or Acquire and Release.
1202   // In addition, Unequal cannot be set to a stronger memory-order then Equal.
1203 
1204   return success();
1205 }
1206 
1207 //===----------------------------------------------------------------------===//
1208 // spv.AtomicAndOp
1209 //===----------------------------------------------------------------------===//
1210 
verify()1211 LogicalResult spirv::AtomicAndOp::verify() {
1212   return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1213 }
1214 
parse(OpAsmParser & parser,OperationState & result)1215 ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser,
1216                                       OperationState &result) {
1217   return ::parseAtomicUpdateOp(parser, result, true);
1218 }
print(OpAsmPrinter & p)1219 void spirv::AtomicAndOp::print(OpAsmPrinter &p) {
1220   ::printAtomicUpdateOp(*this, p);
1221 }
1222 
1223 //===----------------------------------------------------------------------===//
1224 // spv.AtomicCompareExchangeOp
1225 //===----------------------------------------------------------------------===//
1226 
verify()1227 LogicalResult spirv::AtomicCompareExchangeOp::verify() {
1228   return ::verifyAtomicCompareExchangeImpl(*this);
1229 }
1230 
parse(OpAsmParser & parser,OperationState & result)1231 ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser,
1232                                                   OperationState &result) {
1233   return ::parseAtomicCompareExchangeImpl(parser, result);
1234 }
print(OpAsmPrinter & p)1235 void spirv::AtomicCompareExchangeOp::print(OpAsmPrinter &p) {
1236   ::printAtomicCompareExchangeImpl(*this, p);
1237 }
1238 
1239 //===----------------------------------------------------------------------===//
1240 // spv.AtomicCompareExchangeWeakOp
1241 //===----------------------------------------------------------------------===//
1242 
verify()1243 LogicalResult spirv::AtomicCompareExchangeWeakOp::verify() {
1244   return ::verifyAtomicCompareExchangeImpl(*this);
1245 }
1246 
parse(OpAsmParser & parser,OperationState & result)1247 ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
1248                                                       OperationState &result) {
1249   return ::parseAtomicCompareExchangeImpl(parser, result);
1250 }
print(OpAsmPrinter & p)1251 void spirv::AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) {
1252   ::printAtomicCompareExchangeImpl(*this, p);
1253 }
1254 
1255 //===----------------------------------------------------------------------===//
1256 // spv.AtomicExchange
1257 //===----------------------------------------------------------------------===//
1258 
print(OpAsmPrinter & printer)1259 void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) {
1260   printer << " \"" << stringifyScope(memory_scope()) << "\" \""
1261           << stringifyMemorySemantics(semantics()) << "\" " << getOperands()
1262           << " : " << pointer().getType();
1263 }
1264 
parse(OpAsmParser & parser,OperationState & state)1265 ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
1266                                            OperationState &state) {
1267   spirv::Scope memoryScope;
1268   spirv::MemorySemantics semantics;
1269   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
1270   Type type;
1271   if (parseEnumStrAttr(memoryScope, parser, state, kMemoryScopeAttrName) ||
1272       parseEnumStrAttr(semantics, parser, state, kSemanticsAttrName) ||
1273       parser.parseOperandList(operandInfo, 2))
1274     return failure();
1275 
1276   auto loc = parser.getCurrentLocation();
1277   if (parser.parseColonType(type))
1278     return failure();
1279 
1280   auto ptrType = type.dyn_cast<spirv::PointerType>();
1281   if (!ptrType)
1282     return parser.emitError(loc, "expected pointer type");
1283 
1284   if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
1285                              parser.getNameLoc(), state.operands))
1286     return failure();
1287 
1288   return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1289 }
1290 
verify()1291 LogicalResult spirv::AtomicExchangeOp::verify() {
1292   if (getType() != value().getType())
1293     return emitOpError("value operand must have the same type as the op "
1294                        "result, but found ")
1295            << value().getType() << " vs " << getType();
1296 
1297   Type pointeeType =
1298       pointer().getType().cast<spirv::PointerType>().getPointeeType();
1299   if (getType() != pointeeType)
1300     return emitOpError("pointer operand's pointee type must have the same "
1301                        "as the op result type, but found ")
1302            << pointeeType << " vs " << getType();
1303 
1304   return success();
1305 }
1306 
1307 //===----------------------------------------------------------------------===//
1308 // spv.AtomicIAddOp
1309 //===----------------------------------------------------------------------===//
1310 
verify()1311 LogicalResult spirv::AtomicIAddOp::verify() {
1312   return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1313 }
1314 
parse(OpAsmParser & parser,OperationState & result)1315 ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser,
1316                                        OperationState &result) {
1317   return ::parseAtomicUpdateOp(parser, result, true);
1318 }
print(OpAsmPrinter & p)1319 void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
1320   ::printAtomicUpdateOp(*this, p);
1321 }
1322 
1323 //===----------------------------------------------------------------------===//
1324 // spv.AtomicFAddEXTOp
1325 //===----------------------------------------------------------------------===//
1326 
verify()1327 LogicalResult spirv::AtomicFAddEXTOp::verify() {
1328   return ::verifyAtomicUpdateOp<FloatType>(getOperation());
1329 }
1330 
parse(OpAsmParser & parser,OperationState & result)1331 ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser,
1332                                           OperationState &result) {
1333   return ::parseAtomicUpdateOp(parser, result, true);
1334 }
print(OpAsmPrinter & p)1335 void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) {
1336   ::printAtomicUpdateOp(*this, p);
1337 }
1338 
1339 //===----------------------------------------------------------------------===//
1340 // spv.AtomicIDecrementOp
1341 //===----------------------------------------------------------------------===//
1342 
verify()1343 LogicalResult spirv::AtomicIDecrementOp::verify() {
1344   return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1345 }
1346 
parse(OpAsmParser & parser,OperationState & result)1347 ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser,
1348                                              OperationState &result) {
1349   return ::parseAtomicUpdateOp(parser, result, false);
1350 }
print(OpAsmPrinter & p)1351 void spirv::AtomicIDecrementOp::print(OpAsmPrinter &p) {
1352   ::printAtomicUpdateOp(*this, p);
1353 }
1354 
1355 //===----------------------------------------------------------------------===//
1356 // spv.AtomicIIncrementOp
1357 //===----------------------------------------------------------------------===//
1358 
verify()1359 LogicalResult spirv::AtomicIIncrementOp::verify() {
1360   return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1361 }
1362 
parse(OpAsmParser & parser,OperationState & result)1363 ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser,
1364                                              OperationState &result) {
1365   return ::parseAtomicUpdateOp(parser, result, false);
1366 }
print(OpAsmPrinter & p)1367 void spirv::AtomicIIncrementOp::print(OpAsmPrinter &p) {
1368   ::printAtomicUpdateOp(*this, p);
1369 }
1370 
1371 //===----------------------------------------------------------------------===//
1372 // spv.AtomicISubOp
1373 //===----------------------------------------------------------------------===//
1374 
verify()1375 LogicalResult spirv::AtomicISubOp::verify() {
1376   return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1377 }
1378 
parse(OpAsmParser & parser,OperationState & result)1379 ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser,
1380                                        OperationState &result) {
1381   return ::parseAtomicUpdateOp(parser, result, true);
1382 }
print(OpAsmPrinter & p)1383 void spirv::AtomicISubOp::print(OpAsmPrinter &p) {
1384   ::printAtomicUpdateOp(*this, p);
1385 }
1386 
1387 //===----------------------------------------------------------------------===//
1388 // spv.AtomicOrOp
1389 //===----------------------------------------------------------------------===//
1390 
verify()1391 LogicalResult spirv::AtomicOrOp::verify() {
1392   return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1393 }
1394 
parse(OpAsmParser & parser,OperationState & result)1395 ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser,
1396                                      OperationState &result) {
1397   return ::parseAtomicUpdateOp(parser, result, true);
1398 }
print(OpAsmPrinter & p)1399 void spirv::AtomicOrOp::print(OpAsmPrinter &p) {
1400   ::printAtomicUpdateOp(*this, p);
1401 }
1402 
1403 //===----------------------------------------------------------------------===//
1404 // spv.AtomicSMaxOp
1405 //===----------------------------------------------------------------------===//
1406 
verify()1407 LogicalResult spirv::AtomicSMaxOp::verify() {
1408   return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1409 }
1410 
parse(OpAsmParser & parser,OperationState & result)1411 ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser,
1412                                        OperationState &result) {
1413   return ::parseAtomicUpdateOp(parser, result, true);
1414 }
print(OpAsmPrinter & p)1415 void spirv::AtomicSMaxOp::print(OpAsmPrinter &p) {
1416   ::printAtomicUpdateOp(*this, p);
1417 }
1418 
1419 //===----------------------------------------------------------------------===//
1420 // spv.AtomicSMinOp
1421 //===----------------------------------------------------------------------===//
1422 
verify()1423 LogicalResult spirv::AtomicSMinOp::verify() {
1424   return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1425 }
1426 
parse(OpAsmParser & parser,OperationState & result)1427 ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser,
1428                                        OperationState &result) {
1429   return ::parseAtomicUpdateOp(parser, result, true);
1430 }
print(OpAsmPrinter & p)1431 void spirv::AtomicSMinOp::print(OpAsmPrinter &p) {
1432   ::printAtomicUpdateOp(*this, p);
1433 }
1434 
1435 //===----------------------------------------------------------------------===//
1436 // spv.AtomicUMaxOp
1437 //===----------------------------------------------------------------------===//
1438 
verify()1439 LogicalResult spirv::AtomicUMaxOp::verify() {
1440   return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1441 }
1442 
parse(OpAsmParser & parser,OperationState & result)1443 ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser,
1444                                        OperationState &result) {
1445   return ::parseAtomicUpdateOp(parser, result, true);
1446 }
print(OpAsmPrinter & p)1447 void spirv::AtomicUMaxOp::print(OpAsmPrinter &p) {
1448   ::printAtomicUpdateOp(*this, p);
1449 }
1450 
1451 //===----------------------------------------------------------------------===//
1452 // spv.AtomicUMinOp
1453 //===----------------------------------------------------------------------===//
1454 
verify()1455 LogicalResult spirv::AtomicUMinOp::verify() {
1456   return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1457 }
1458 
parse(OpAsmParser & parser,OperationState & result)1459 ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser,
1460                                        OperationState &result) {
1461   return ::parseAtomicUpdateOp(parser, result, true);
1462 }
print(OpAsmPrinter & p)1463 void spirv::AtomicUMinOp::print(OpAsmPrinter &p) {
1464   ::printAtomicUpdateOp(*this, p);
1465 }
1466 
1467 //===----------------------------------------------------------------------===//
1468 // spv.AtomicXorOp
1469 //===----------------------------------------------------------------------===//
1470 
verify()1471 LogicalResult spirv::AtomicXorOp::verify() {
1472   return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1473 }
1474 
parse(OpAsmParser & parser,OperationState & result)1475 ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser,
1476                                       OperationState &result) {
1477   return ::parseAtomicUpdateOp(parser, result, true);
1478 }
print(OpAsmPrinter & p)1479 void spirv::AtomicXorOp::print(OpAsmPrinter &p) {
1480   ::printAtomicUpdateOp(*this, p);
1481 }
1482 
1483 //===----------------------------------------------------------------------===//
1484 // spv.BitcastOp
1485 //===----------------------------------------------------------------------===//
1486 
verify()1487 LogicalResult spirv::BitcastOp::verify() {
1488   // TODO: The SPIR-V spec validation rules are different for different
1489   // versions.
1490   auto operandType = operand().getType();
1491   auto resultType = result().getType();
1492   if (operandType == resultType) {
1493     return emitError("result type must be different from operand type");
1494   }
1495   if (operandType.isa<spirv::PointerType>() &&
1496       !resultType.isa<spirv::PointerType>()) {
1497     return emitError(
1498         "unhandled bit cast conversion from pointer type to non-pointer type");
1499   }
1500   if (!operandType.isa<spirv::PointerType>() &&
1501       resultType.isa<spirv::PointerType>()) {
1502     return emitError(
1503         "unhandled bit cast conversion from non-pointer type to pointer type");
1504   }
1505   auto operandBitWidth = getBitWidth(operandType);
1506   auto resultBitWidth = getBitWidth(resultType);
1507   if (operandBitWidth != resultBitWidth) {
1508     return emitOpError("mismatch in result type bitwidth ")
1509            << resultBitWidth << " and operand type bitwidth "
1510            << operandBitWidth;
1511   }
1512   return success();
1513 }
1514 
1515 //===----------------------------------------------------------------------===//
1516 // spv.BranchOp
1517 //===----------------------------------------------------------------------===//
1518 
getSuccessorOperands(unsigned index)1519 SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
1520   assert(index == 0 && "invalid successor index");
1521   return SuccessorOperands(0, targetOperandsMutable());
1522 }
1523 
1524 //===----------------------------------------------------------------------===//
1525 // spv.BranchConditionalOp
1526 //===----------------------------------------------------------------------===//
1527 
1528 SuccessorOperands
getSuccessorOperands(unsigned index)1529 spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
1530   assert(index < 2 && "invalid successor index");
1531   return SuccessorOperands(index == kTrueIndex ? trueTargetOperandsMutable()
1532                                                : falseTargetOperandsMutable());
1533 }
1534 
parse(OpAsmParser & parser,OperationState & state)1535 ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
1536                                               OperationState &state) {
1537   auto &builder = parser.getBuilder();
1538   OpAsmParser::UnresolvedOperand condInfo;
1539   Block *dest;
1540 
1541   // Parse the condition.
1542   Type boolTy = builder.getI1Type();
1543   if (parser.parseOperand(condInfo) ||
1544       parser.resolveOperand(condInfo, boolTy, state.operands))
1545     return failure();
1546 
1547   // Parse the optional branch weights.
1548   if (succeeded(parser.parseOptionalLSquare())) {
1549     IntegerAttr trueWeight, falseWeight;
1550     NamedAttrList weights;
1551 
1552     auto i32Type = builder.getIntegerType(32);
1553     if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
1554         parser.parseComma() ||
1555         parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
1556         parser.parseRSquare())
1557       return failure();
1558 
1559     state.addAttribute(kBranchWeightAttrName,
1560                        builder.getArrayAttr({trueWeight, falseWeight}));
1561   }
1562 
1563   // Parse the true branch.
1564   SmallVector<Value, 4> trueOperands;
1565   if (parser.parseComma() ||
1566       parser.parseSuccessorAndUseList(dest, trueOperands))
1567     return failure();
1568   state.addSuccessors(dest);
1569   state.addOperands(trueOperands);
1570 
1571   // Parse the false branch.
1572   SmallVector<Value, 4> falseOperands;
1573   if (parser.parseComma() ||
1574       parser.parseSuccessorAndUseList(dest, falseOperands))
1575     return failure();
1576   state.addSuccessors(dest);
1577   state.addOperands(falseOperands);
1578   state.addAttribute(
1579       spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
1580       builder.getI32VectorAttr({1, static_cast<int32_t>(trueOperands.size()),
1581                                 static_cast<int32_t>(falseOperands.size())}));
1582 
1583   return success();
1584 }
1585 
print(OpAsmPrinter & printer)1586 void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) {
1587   printer << ' ' << condition();
1588 
1589   if (auto weights = branch_weights()) {
1590     printer << " [";
1591     llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
1592       printer << a.cast<IntegerAttr>().getInt();
1593     });
1594     printer << "]";
1595   }
1596 
1597   printer << ", ";
1598   printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
1599   printer << ", ";
1600   printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
1601 }
1602 
verify()1603 LogicalResult spirv::BranchConditionalOp::verify() {
1604   if (auto weights = branch_weights()) {
1605     if (weights->getValue().size() != 2) {
1606       return emitOpError("must have exactly two branch weights");
1607     }
1608     if (llvm::all_of(*weights, [](Attribute attr) {
1609           return attr.cast<IntegerAttr>().getValue().isNullValue();
1610         }))
1611       return emitOpError("branch weights cannot both be zero");
1612   }
1613 
1614   return success();
1615 }
1616 
1617 //===----------------------------------------------------------------------===//
1618 // spv.CompositeConstruct
1619 //===----------------------------------------------------------------------===//
1620 
parse(OpAsmParser & parser,OperationState & state)1621 ParseResult spirv::CompositeConstructOp::parse(OpAsmParser &parser,
1622                                                OperationState &state) {
1623   SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
1624   Type type;
1625   auto loc = parser.getCurrentLocation();
1626 
1627   if (parser.parseOperandList(operands) || parser.parseColonType(type)) {
1628     return failure();
1629   }
1630   auto cType = type.dyn_cast<spirv::CompositeType>();
1631   if (!cType) {
1632     return parser.emitError(
1633                loc, "result type must be a composite type, but provided ")
1634            << type;
1635   }
1636 
1637   if (cType.hasCompileTimeKnownNumElements() &&
1638       operands.size() != cType.getNumElements()) {
1639     return parser.emitError(loc, "has incorrect number of operands: expected ")
1640            << cType.getNumElements() << ", but provided " << operands.size();
1641   }
1642   // TODO: Add support for constructing a vector type from the vector operands.
1643   // According to the spec: "for constructing a vector, the operands may
1644   // also be vectors with the same component type as the Result Type component
1645   // type".
1646   SmallVector<Type, 4> elementTypes;
1647   elementTypes.reserve(operands.size());
1648   for (auto index : llvm::seq<uint32_t>(0, operands.size())) {
1649     elementTypes.push_back(cType.getElementType(index));
1650   }
1651   state.addTypes(type);
1652   return parser.resolveOperands(operands, elementTypes, loc, state.operands);
1653 }
1654 
print(OpAsmPrinter & printer)1655 void spirv::CompositeConstructOp::print(OpAsmPrinter &printer) {
1656   printer << " " << constituents() << " : " << getResult().getType();
1657 }
1658 
verify()1659 LogicalResult spirv::CompositeConstructOp::verify() {
1660   auto cType = getType().cast<spirv::CompositeType>();
1661   operand_range constituents = this->constituents();
1662 
1663   if (cType.isa<spirv::CooperativeMatrixNVType>()) {
1664     if (constituents.size() != 1)
1665       return emitError("has incorrect number of operands: expected ")
1666              << "1, but provided " << constituents.size();
1667   } else if (constituents.size() != cType.getNumElements()) {
1668     return emitError("has incorrect number of operands: expected ")
1669            << cType.getNumElements() << ", but provided "
1670            << constituents.size();
1671   }
1672 
1673   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1674     if (constituents[index].getType() != cType.getElementType(index)) {
1675       return emitError("operand type mismatch: expected operand type ")
1676              << cType.getElementType(index) << ", but provided "
1677              << constituents[index].getType();
1678     }
1679   }
1680 
1681   return success();
1682 }
1683 
1684 //===----------------------------------------------------------------------===//
1685 // spv.CompositeExtractOp
1686 //===----------------------------------------------------------------------===//
1687 
build(OpBuilder & builder,OperationState & state,Value composite,ArrayRef<int32_t> indices)1688 void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
1689                                       Value composite,
1690                                       ArrayRef<int32_t> indices) {
1691   auto indexAttr = builder.getI32ArrayAttr(indices);
1692   auto elementType =
1693       getElementType(composite.getType(), indexAttr, state.location);
1694   if (!elementType) {
1695     return;
1696   }
1697   build(builder, state, elementType, composite, indexAttr);
1698 }
1699 
parse(OpAsmParser & parser,OperationState & state)1700 ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
1701                                              OperationState &state) {
1702   OpAsmParser::UnresolvedOperand compositeInfo;
1703   Attribute indicesAttr;
1704   Type compositeType;
1705   SMLoc attrLocation;
1706 
1707   if (parser.parseOperand(compositeInfo) ||
1708       parser.getCurrentLocation(&attrLocation) ||
1709       parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1710       parser.parseColonType(compositeType) ||
1711       parser.resolveOperand(compositeInfo, compositeType, state.operands)) {
1712     return failure();
1713   }
1714 
1715   Type resultType =
1716       getElementType(compositeType, indicesAttr, parser, attrLocation);
1717   if (!resultType) {
1718     return failure();
1719   }
1720   state.addTypes(resultType);
1721   return success();
1722 }
1723 
print(OpAsmPrinter & printer)1724 void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
1725   printer << ' ' << composite() << indices() << " : " << composite().getType();
1726 }
1727 
verify()1728 LogicalResult spirv::CompositeExtractOp::verify() {
1729   auto indicesArrayAttr = indices().dyn_cast<ArrayAttr>();
1730   auto resultType =
1731       getElementType(composite().getType(), indicesArrayAttr, getLoc());
1732   if (!resultType)
1733     return failure();
1734 
1735   if (resultType != getType()) {
1736     return emitOpError("invalid result type: expected ")
1737            << resultType << " but provided " << getType();
1738   }
1739 
1740   return success();
1741 }
1742 
1743 //===----------------------------------------------------------------------===//
1744 // spv.CompositeInsert
1745 //===----------------------------------------------------------------------===//
1746 
build(OpBuilder & builder,OperationState & state,Value object,Value composite,ArrayRef<int32_t> indices)1747 void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
1748                                      Value object, Value composite,
1749                                      ArrayRef<int32_t> indices) {
1750   auto indexAttr = builder.getI32ArrayAttr(indices);
1751   build(builder, state, composite.getType(), object, composite, indexAttr);
1752 }
1753 
parse(OpAsmParser & parser,OperationState & state)1754 ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
1755                                             OperationState &state) {
1756   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
1757   Type objectType, compositeType;
1758   Attribute indicesAttr;
1759   auto loc = parser.getCurrentLocation();
1760 
1761   return failure(
1762       parser.parseOperandList(operands, 2) ||
1763       parser.parseAttribute(indicesAttr, kIndicesAttrName, state.attributes) ||
1764       parser.parseColonType(objectType) ||
1765       parser.parseKeywordType("into", compositeType) ||
1766       parser.resolveOperands(operands, {objectType, compositeType}, loc,
1767                              state.operands) ||
1768       parser.addTypesToList(compositeType, state.types));
1769 }
1770 
verify()1771 LogicalResult spirv::CompositeInsertOp::verify() {
1772   auto indicesArrayAttr = indices().dyn_cast<ArrayAttr>();
1773   auto objectType =
1774       getElementType(composite().getType(), indicesArrayAttr, getLoc());
1775   if (!objectType)
1776     return failure();
1777 
1778   if (objectType != object().getType()) {
1779     return emitOpError("object operand type should be ")
1780            << objectType << ", but found " << object().getType();
1781   }
1782 
1783   if (composite().getType() != getType()) {
1784     return emitOpError("result type should be the same as "
1785                        "the composite type, but found ")
1786            << composite().getType() << " vs " << getType();
1787   }
1788 
1789   return success();
1790 }
1791 
print(OpAsmPrinter & printer)1792 void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
1793   printer << " " << object() << ", " << composite() << indices() << " : "
1794           << object().getType() << " into " << composite().getType();
1795 }
1796 
1797 //===----------------------------------------------------------------------===//
1798 // spv.Constant
1799 //===----------------------------------------------------------------------===//
1800 
parse(OpAsmParser & parser,OperationState & state)1801 ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
1802                                      OperationState &state) {
1803   Attribute value;
1804   if (parser.parseAttribute(value, kValueAttrName, state.attributes))
1805     return failure();
1806 
1807   Type type = value.getType();
1808   if (type.isa<NoneType, TensorType>()) {
1809     if (parser.parseColonType(type))
1810       return failure();
1811   }
1812 
1813   return parser.addTypeToList(type, state.types);
1814 }
1815 
print(OpAsmPrinter & printer)1816 void spirv::ConstantOp::print(OpAsmPrinter &printer) {
1817   printer << ' ' << value();
1818   if (getType().isa<spirv::ArrayType>())
1819     printer << " : " << getType();
1820 }
1821 
verifyConstantType(spirv::ConstantOp op,Attribute value,Type opType)1822 static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
1823                                         Type opType) {
1824   auto valueType = value.getType();
1825 
1826   if (value.isa<IntegerAttr, FloatAttr>()) {
1827     if (valueType != opType)
1828       return op.emitOpError("result type (")
1829              << opType << ") does not match value type (" << valueType << ")";
1830     return success();
1831   }
1832   if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
1833     if (valueType == opType)
1834       return success();
1835     auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1836     auto shapedType = valueType.dyn_cast<ShapedType>();
1837     if (!arrayType)
1838       return op.emitOpError("result or element type (")
1839              << opType << ") does not match value type (" << valueType
1840              << "), must be the same or spv.array";
1841 
1842     int numElements = arrayType.getNumElements();
1843     auto opElemType = arrayType.getElementType();
1844     while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
1845       numElements *= t.getNumElements();
1846       opElemType = t.getElementType();
1847     }
1848     if (!opElemType.isIntOrFloat())
1849       return op.emitOpError("only support nested array result type");
1850 
1851     auto valueElemType = shapedType.getElementType();
1852     if (valueElemType != opElemType) {
1853       return op.emitOpError("result element type (")
1854              << opElemType << ") does not match value element type ("
1855              << valueElemType << ")";
1856     }
1857 
1858     if (numElements != shapedType.getNumElements()) {
1859       return op.emitOpError("result number of elements (")
1860              << numElements << ") does not match value number of elements ("
1861              << shapedType.getNumElements() << ")";
1862     }
1863     return success();
1864   }
1865   if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) {
1866     auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1867     if (!arrayType)
1868       return op.emitOpError("must have spv.array result type for array value");
1869     Type elemType = arrayType.getElementType();
1870     for (Attribute element : arrayAttr.getValue()) {
1871       // Verify array elements recursively.
1872       if (failed(verifyConstantType(op, element, elemType)))
1873         return failure();
1874     }
1875     return success();
1876   }
1877   return op.emitOpError("cannot have value of type ") << valueType;
1878 }
1879 
verify()1880 LogicalResult spirv::ConstantOp::verify() {
1881   // ODS already generates checks to make sure the result type is valid. We just
1882   // need to additionally check that the value's attribute type is consistent
1883   // with the result type.
1884   return verifyConstantType(*this, valueAttr(), getType());
1885 }
1886 
isBuildableWith(Type type)1887 bool spirv::ConstantOp::isBuildableWith(Type type) {
1888   // Must be valid SPIR-V type first.
1889   if (!type.isa<spirv::SPIRVType>())
1890     return false;
1891 
1892   if (isa<SPIRVDialect>(type.getDialect())) {
1893     // TODO: support constant struct
1894     return type.isa<spirv::ArrayType>();
1895   }
1896 
1897   return true;
1898 }
1899 
getZero(Type type,Location loc,OpBuilder & builder)1900 spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
1901                                              OpBuilder &builder) {
1902   if (auto intType = type.dyn_cast<IntegerType>()) {
1903     unsigned width = intType.getWidth();
1904     if (width == 1)
1905       return builder.create<spirv::ConstantOp>(loc, type,
1906                                                builder.getBoolAttr(false));
1907     return builder.create<spirv::ConstantOp>(
1908         loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
1909   }
1910   if (auto floatType = type.dyn_cast<FloatType>()) {
1911     return builder.create<spirv::ConstantOp>(
1912         loc, type, builder.getFloatAttr(floatType, 0.0));
1913   }
1914   if (auto vectorType = type.dyn_cast<VectorType>()) {
1915     Type elemType = vectorType.getElementType();
1916     if (elemType.isa<IntegerType>()) {
1917       return builder.create<spirv::ConstantOp>(
1918           loc, type,
1919           DenseElementsAttr::get(vectorType,
1920                                  IntegerAttr::get(elemType, 0.0).getValue()));
1921     }
1922     if (elemType.isa<FloatType>()) {
1923       return builder.create<spirv::ConstantOp>(
1924           loc, type,
1925           DenseFPElementsAttr::get(vectorType,
1926                                    FloatAttr::get(elemType, 0.0).getValue()));
1927     }
1928   }
1929 
1930   llvm_unreachable("unimplemented types for ConstantOp::getZero()");
1931 }
1932 
getOne(Type type,Location loc,OpBuilder & builder)1933 spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
1934                                             OpBuilder &builder) {
1935   if (auto intType = type.dyn_cast<IntegerType>()) {
1936     unsigned width = intType.getWidth();
1937     if (width == 1)
1938       return builder.create<spirv::ConstantOp>(loc, type,
1939                                                builder.getBoolAttr(true));
1940     return builder.create<spirv::ConstantOp>(
1941         loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
1942   }
1943   if (auto floatType = type.dyn_cast<FloatType>()) {
1944     return builder.create<spirv::ConstantOp>(
1945         loc, type, builder.getFloatAttr(floatType, 1.0));
1946   }
1947   if (auto vectorType = type.dyn_cast<VectorType>()) {
1948     Type elemType = vectorType.getElementType();
1949     if (elemType.isa<IntegerType>()) {
1950       return builder.create<spirv::ConstantOp>(
1951           loc, type,
1952           DenseElementsAttr::get(vectorType,
1953                                  IntegerAttr::get(elemType, 1.0).getValue()));
1954     }
1955     if (elemType.isa<FloatType>()) {
1956       return builder.create<spirv::ConstantOp>(
1957           loc, type,
1958           DenseFPElementsAttr::get(vectorType,
1959                                    FloatAttr::get(elemType, 1.0).getValue()));
1960     }
1961   }
1962 
1963   llvm_unreachable("unimplemented types for ConstantOp::getOne()");
1964 }
1965 
getAsmResultNames(llvm::function_ref<void (mlir::Value,llvm::StringRef)> setNameFn)1966 void mlir::spirv::ConstantOp::getAsmResultNames(
1967     llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
1968   Type type = getType();
1969 
1970   SmallString<32> specialNameBuffer;
1971   llvm::raw_svector_ostream specialName(specialNameBuffer);
1972   specialName << "cst";
1973 
1974   IntegerType intTy = type.dyn_cast<IntegerType>();
1975 
1976   if (IntegerAttr intCst = value().dyn_cast<IntegerAttr>()) {
1977     if (intTy && intTy.getWidth() == 1) {
1978       return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
1979     }
1980 
1981     if (intTy.isSignless()) {
1982       specialName << intCst.getInt();
1983     } else {
1984       specialName << intCst.getSInt();
1985     }
1986   }
1987 
1988   if (intTy || type.isa<FloatType>()) {
1989     specialName << '_' << type;
1990   }
1991 
1992   if (auto vecType = type.dyn_cast<VectorType>()) {
1993     specialName << "_vec_";
1994     specialName << vecType.getDimSize(0);
1995 
1996     Type elementType = vecType.getElementType();
1997 
1998     if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
1999       specialName << "x" << elementType;
2000     }
2001   }
2002 
2003   setNameFn(getResult(), specialName.str());
2004 }
2005 
getAsmResultNames(llvm::function_ref<void (mlir::Value,llvm::StringRef)> setNameFn)2006 void mlir::spirv::AddressOfOp::getAsmResultNames(
2007     llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
2008   SmallString<32> specialNameBuffer;
2009   llvm::raw_svector_ostream specialName(specialNameBuffer);
2010   specialName << variable() << "_addr";
2011   setNameFn(getResult(), specialName.str());
2012 }
2013 
2014 //===----------------------------------------------------------------------===//
2015 // spv.ControlBarrierOp
2016 //===----------------------------------------------------------------------===//
2017 
verify()2018 LogicalResult spirv::ControlBarrierOp::verify() {
2019   return verifyMemorySemantics(getOperation(), memory_semantics());
2020 }
2021 
2022 //===----------------------------------------------------------------------===//
2023 // spv.ConvertFToSOp
2024 //===----------------------------------------------------------------------===//
2025 
verify()2026 LogicalResult spirv::ConvertFToSOp::verify() {
2027   return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2028                       /*skipBitWidthCheck=*/true);
2029 }
2030 
2031 //===----------------------------------------------------------------------===//
2032 // spv.ConvertFToUOp
2033 //===----------------------------------------------------------------------===//
2034 
verify()2035 LogicalResult spirv::ConvertFToUOp::verify() {
2036   return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2037                       /*skipBitWidthCheck=*/true);
2038 }
2039 
2040 //===----------------------------------------------------------------------===//
2041 // spv.ConvertSToFOp
2042 //===----------------------------------------------------------------------===//
2043 
verify()2044 LogicalResult spirv::ConvertSToFOp::verify() {
2045   return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2046                       /*skipBitWidthCheck=*/true);
2047 }
2048 
2049 //===----------------------------------------------------------------------===//
2050 // spv.ConvertUToFOp
2051 //===----------------------------------------------------------------------===//
2052 
verify()2053 LogicalResult spirv::ConvertUToFOp::verify() {
2054   return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2055                       /*skipBitWidthCheck=*/true);
2056 }
2057 
2058 //===----------------------------------------------------------------------===//
2059 // spv.EntryPoint
2060 //===----------------------------------------------------------------------===//
2061 
build(OpBuilder & builder,OperationState & state,spirv::ExecutionModel executionModel,spirv::FuncOp function,ArrayRef<Attribute> interfaceVars)2062 void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
2063                                 spirv::ExecutionModel executionModel,
2064                                 spirv::FuncOp function,
2065                                 ArrayRef<Attribute> interfaceVars) {
2066   build(builder, state,
2067         spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
2068         SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
2069 }
2070 
parse(OpAsmParser & parser,OperationState & state)2071 ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
2072                                        OperationState &state) {
2073   spirv::ExecutionModel execModel;
2074   SmallVector<OpAsmParser::UnresolvedOperand, 0> identifiers;
2075   SmallVector<Type, 0> idTypes;
2076   SmallVector<Attribute, 4> interfaceVars;
2077 
2078   FlatSymbolRefAttr fn;
2079   if (parseEnumStrAttr(execModel, parser, state) ||
2080       parser.parseAttribute(fn, Type(), kFnNameAttrName, state.attributes)) {
2081     return failure();
2082   }
2083 
2084   if (!parser.parseOptionalComma()) {
2085     // Parse the interface variables
2086     if (parser.parseCommaSeparatedList([&]() -> ParseResult {
2087           // The name of the interface variable attribute isnt important
2088           FlatSymbolRefAttr var;
2089           NamedAttrList attrs;
2090           if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
2091             return failure();
2092           interfaceVars.push_back(var);
2093           return success();
2094         }))
2095       return failure();
2096   }
2097   state.addAttribute(kInterfaceAttrName,
2098                      parser.getBuilder().getArrayAttr(interfaceVars));
2099   return success();
2100 }
2101 
print(OpAsmPrinter & printer)2102 void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
2103   printer << " \"" << stringifyExecutionModel(execution_model()) << "\" ";
2104   printer.printSymbolName(fn());
2105   auto interfaceVars = interface().getValue();
2106   if (!interfaceVars.empty()) {
2107     printer << ", ";
2108     llvm::interleaveComma(interfaceVars, printer);
2109   }
2110 }
2111 
verify()2112 LogicalResult spirv::EntryPointOp::verify() {
2113   // Checks for fn and interface symbol reference are done in spirv::ModuleOp
2114   // verification.
2115   return success();
2116 }
2117 
2118 //===----------------------------------------------------------------------===//
2119 // spv.ExecutionMode
2120 //===----------------------------------------------------------------------===//
2121 
build(OpBuilder & builder,OperationState & state,spirv::FuncOp function,spirv::ExecutionMode executionMode,ArrayRef<int32_t> params)2122 void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
2123                                    spirv::FuncOp function,
2124                                    spirv::ExecutionMode executionMode,
2125                                    ArrayRef<int32_t> params) {
2126   build(builder, state, SymbolRefAttr::get(function),
2127         spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
2128         builder.getI32ArrayAttr(params));
2129 }
2130 
parse(OpAsmParser & parser,OperationState & state)2131 ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
2132                                           OperationState &state) {
2133   spirv::ExecutionMode execMode;
2134   Attribute fn;
2135   if (parser.parseAttribute(fn, kFnNameAttrName, state.attributes) ||
2136       parseEnumStrAttr(execMode, parser, state)) {
2137     return failure();
2138   }
2139 
2140   SmallVector<int32_t, 4> values;
2141   Type i32Type = parser.getBuilder().getIntegerType(32);
2142   while (!parser.parseOptionalComma()) {
2143     NamedAttrList attr;
2144     Attribute value;
2145     if (parser.parseAttribute(value, i32Type, "value", attr)) {
2146       return failure();
2147     }
2148     values.push_back(value.cast<IntegerAttr>().getInt());
2149   }
2150   state.addAttribute(kValuesAttrName,
2151                      parser.getBuilder().getI32ArrayAttr(values));
2152   return success();
2153 }
2154 
print(OpAsmPrinter & printer)2155 void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
2156   printer << " ";
2157   printer.printSymbolName(fn());
2158   printer << " \"" << stringifyExecutionMode(execution_mode()) << "\"";
2159   auto values = this->values();
2160   if (values.empty())
2161     return;
2162   printer << ", ";
2163   llvm::interleaveComma(values, printer, [&](Attribute a) {
2164     printer << a.cast<IntegerAttr>().getInt();
2165   });
2166 }
2167 
2168 //===----------------------------------------------------------------------===//
2169 // spv.FConvertOp
2170 //===----------------------------------------------------------------------===//
2171 
verify()2172 LogicalResult spirv::FConvertOp::verify() {
2173   return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2174 }
2175 
2176 //===----------------------------------------------------------------------===//
2177 // spv.SConvertOp
2178 //===----------------------------------------------------------------------===//
2179 
verify()2180 LogicalResult spirv::SConvertOp::verify() {
2181   return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2182 }
2183 
2184 //===----------------------------------------------------------------------===//
2185 // spv.UConvertOp
2186 //===----------------------------------------------------------------------===//
2187 
verify()2188 LogicalResult spirv::UConvertOp::verify() {
2189   return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2190 }
2191 
2192 //===----------------------------------------------------------------------===//
2193 // spv.func
2194 //===----------------------------------------------------------------------===//
2195 
parse(OpAsmParser & parser,OperationState & state)2196 ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &state) {
2197   SmallVector<OpAsmParser::Argument> entryArgs;
2198   SmallVector<DictionaryAttr> resultAttrs;
2199   SmallVector<Type> resultTypes;
2200   auto &builder = parser.getBuilder();
2201 
2202   // Parse the name as a symbol.
2203   StringAttr nameAttr;
2204   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2205                              state.attributes))
2206     return failure();
2207 
2208   // Parse the function signature.
2209   bool isVariadic = false;
2210   if (function_interface_impl::parseFunctionSignature(
2211           parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
2212           resultAttrs))
2213     return failure();
2214 
2215   SmallVector<Type> argTypes;
2216   for (auto &arg : entryArgs)
2217     argTypes.push_back(arg.type);
2218   auto fnType = builder.getFunctionType(argTypes, resultTypes);
2219   state.addAttribute(FunctionOpInterface::getTypeAttrName(),
2220                      TypeAttr::get(fnType));
2221 
2222   // Parse the optional function control keyword.
2223   spirv::FunctionControl fnControl;
2224   if (parseEnumStrAttr(fnControl, parser, state))
2225     return failure();
2226 
2227   // If additional attributes are present, parse them.
2228   if (parser.parseOptionalAttrDictWithKeyword(state.attributes))
2229     return failure();
2230 
2231   // Add the attributes to the function arguments.
2232   assert(resultAttrs.size() == resultTypes.size());
2233   function_interface_impl::addArgAndResultAttrs(builder, state, entryArgs,
2234                                                 resultAttrs);
2235 
2236   // Parse the optional function body.
2237   auto *body = state.addRegion();
2238   OptionalParseResult result = parser.parseOptionalRegion(*body, entryArgs);
2239   return failure(result.hasValue() && failed(*result));
2240 }
2241 
print(OpAsmPrinter & printer)2242 void spirv::FuncOp::print(OpAsmPrinter &printer) {
2243   // Print function name, signature, and control.
2244   printer << " ";
2245   printer.printSymbolName(sym_name());
2246   auto fnType = getFunctionType();
2247   function_interface_impl::printFunctionSignature(
2248       printer, *this, fnType.getInputs(),
2249       /*isVariadic=*/false, fnType.getResults());
2250   printer << " \"" << spirv::stringifyFunctionControl(function_control())
2251           << "\"";
2252   function_interface_impl::printFunctionAttributes(
2253       printer, *this, fnType.getNumInputs(), fnType.getNumResults(),
2254       {spirv::attributeName<spirv::FunctionControl>()});
2255 
2256   // Print the body if this is not an external function.
2257   Region &body = this->body();
2258   if (!body.empty()) {
2259     printer << ' ';
2260     printer.printRegion(body, /*printEntryBlockArgs=*/false,
2261                         /*printBlockTerminators=*/true);
2262   }
2263 }
2264 
verifyType()2265 LogicalResult spirv::FuncOp::verifyType() {
2266   auto type = getFunctionTypeAttr().getValue();
2267   if (!type.isa<FunctionType>())
2268     return emitOpError("requires '" + getTypeAttrName() +
2269                        "' attribute of function type");
2270   if (getFunctionType().getNumResults() > 1)
2271     return emitOpError("cannot have more than one result");
2272   return success();
2273 }
2274 
verifyBody()2275 LogicalResult spirv::FuncOp::verifyBody() {
2276   FunctionType fnType = getFunctionType();
2277 
2278   auto walkResult = walk([fnType](Operation *op) -> WalkResult {
2279     if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
2280       if (fnType.getNumResults() != 0)
2281         return retOp.emitOpError("cannot be used in functions returning value");
2282     } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
2283       if (fnType.getNumResults() != 1)
2284         return retOp.emitOpError(
2285                    "returns 1 value but enclosing function requires ")
2286                << fnType.getNumResults() << " results";
2287 
2288       auto retOperandType = retOp.value().getType();
2289       auto fnResultType = fnType.getResult(0);
2290       if (retOperandType != fnResultType)
2291         return retOp.emitOpError(" return value's type (")
2292                << retOperandType << ") mismatch with function's result type ("
2293                << fnResultType << ")";
2294     }
2295     return WalkResult::advance();
2296   });
2297 
2298   // TODO: verify other bits like linkage type.
2299 
2300   return failure(walkResult.wasInterrupted());
2301 }
2302 
build(OpBuilder & builder,OperationState & state,StringRef name,FunctionType type,spirv::FunctionControl control,ArrayRef<NamedAttribute> attrs)2303 void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
2304                           StringRef name, FunctionType type,
2305                           spirv::FunctionControl control,
2306                           ArrayRef<NamedAttribute> attrs) {
2307   state.addAttribute(SymbolTable::getSymbolAttrName(),
2308                      builder.getStringAttr(name));
2309   state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
2310   state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
2311                      builder.getI32IntegerAttr(static_cast<uint32_t>(control)));
2312   state.attributes.append(attrs.begin(), attrs.end());
2313   state.addRegion();
2314 }
2315 
2316 // CallableOpInterface
getCallableRegion()2317 Region *spirv::FuncOp::getCallableRegion() {
2318   return isExternal() ? nullptr : &body();
2319 }
2320 
2321 // CallableOpInterface
getCallableResults()2322 ArrayRef<Type> spirv::FuncOp::getCallableResults() {
2323   return getFunctionType().getResults();
2324 }
2325 
2326 //===----------------------------------------------------------------------===//
2327 // spv.FunctionCall
2328 //===----------------------------------------------------------------------===//
2329 
verify()2330 LogicalResult spirv::FunctionCallOp::verify() {
2331   auto fnName = calleeAttr();
2332 
2333   auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
2334       SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
2335   if (!funcOp) {
2336     return emitOpError("callee function '")
2337            << fnName.getValue() << "' not found in nearest symbol table";
2338   }
2339 
2340   auto functionType = funcOp.getFunctionType();
2341 
2342   if (getNumResults() > 1) {
2343     return emitOpError(
2344                "expected callee function to have 0 or 1 result, but provided ")
2345            << getNumResults();
2346   }
2347 
2348   if (functionType.getNumInputs() != getNumOperands()) {
2349     return emitOpError("has incorrect number of operands for callee: expected ")
2350            << functionType.getNumInputs() << ", but provided "
2351            << getNumOperands();
2352   }
2353 
2354   for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
2355     if (getOperand(i).getType() != functionType.getInput(i)) {
2356       return emitOpError("operand type mismatch: expected operand type ")
2357              << functionType.getInput(i) << ", but provided "
2358              << getOperand(i).getType() << " for operand number " << i;
2359     }
2360   }
2361 
2362   if (functionType.getNumResults() != getNumResults()) {
2363     return emitOpError(
2364                "has incorrect number of results has for callee: expected ")
2365            << functionType.getNumResults() << ", but provided "
2366            << getNumResults();
2367   }
2368 
2369   if (getNumResults() &&
2370       (getResult(0).getType() != functionType.getResult(0))) {
2371     return emitOpError("result type mismatch: expected ")
2372            << functionType.getResult(0) << ", but provided "
2373            << getResult(0).getType();
2374   }
2375 
2376   return success();
2377 }
2378 
getCallableForCallee()2379 CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
2380   return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
2381 }
2382 
getArgOperands()2383 Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
2384   return arguments();
2385 }
2386 
2387 //===----------------------------------------------------------------------===//
2388 // spv.GLFClampOp
2389 //===----------------------------------------------------------------------===//
2390 
parse(OpAsmParser & parser,OperationState & result)2391 ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
2392                                      OperationState &result) {
2393   return parseOneResultSameOperandTypeOp(parser, result);
2394 }
print(OpAsmPrinter & p)2395 void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2396 
2397 //===----------------------------------------------------------------------===//
2398 // spv.GLUClampOp
2399 //===----------------------------------------------------------------------===//
2400 
parse(OpAsmParser & parser,OperationState & result)2401 ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
2402                                      OperationState &result) {
2403   return parseOneResultSameOperandTypeOp(parser, result);
2404 }
print(OpAsmPrinter & p)2405 void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2406 
2407 //===----------------------------------------------------------------------===//
2408 // spv.GLSClampOp
2409 //===----------------------------------------------------------------------===//
2410 
parse(OpAsmParser & parser,OperationState & result)2411 ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
2412                                      OperationState &result) {
2413   return parseOneResultSameOperandTypeOp(parser, result);
2414 }
print(OpAsmPrinter & p)2415 void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2416 
2417 //===----------------------------------------------------------------------===//
2418 // spv.GLFmaOp
2419 //===----------------------------------------------------------------------===//
2420 
parse(OpAsmParser & parser,OperationState & result)2421 ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
2422   return parseOneResultSameOperandTypeOp(parser, result);
2423 }
print(OpAsmPrinter & p)2424 void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2425 
2426 //===----------------------------------------------------------------------===//
2427 // spv.GlobalVariable
2428 //===----------------------------------------------------------------------===//
2429 
build(OpBuilder & builder,OperationState & state,Type type,StringRef name,unsigned descriptorSet,unsigned binding)2430 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2431                                     Type type, StringRef name,
2432                                     unsigned descriptorSet, unsigned binding) {
2433   build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2434   state.addAttribute(
2435       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
2436       builder.getI32IntegerAttr(descriptorSet));
2437   state.addAttribute(
2438       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
2439       builder.getI32IntegerAttr(binding));
2440 }
2441 
build(OpBuilder & builder,OperationState & state,Type type,StringRef name,spirv::BuiltIn builtin)2442 void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2443                                     Type type, StringRef name,
2444                                     spirv::BuiltIn builtin) {
2445   build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2446   state.addAttribute(
2447       spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
2448       builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
2449 }
2450 
parse(OpAsmParser & parser,OperationState & state)2451 ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
2452                                            OperationState &state) {
2453   // Parse variable name.
2454   StringAttr nameAttr;
2455   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2456                              state.attributes)) {
2457     return failure();
2458   }
2459 
2460   // Parse optional initializer
2461   if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
2462     FlatSymbolRefAttr initSymbol;
2463     if (parser.parseLParen() ||
2464         parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
2465                               state.attributes) ||
2466         parser.parseRParen())
2467       return failure();
2468   }
2469 
2470   if (parseVariableDecorations(parser, state)) {
2471     return failure();
2472   }
2473 
2474   Type type;
2475   auto loc = parser.getCurrentLocation();
2476   if (parser.parseColonType(type)) {
2477     return failure();
2478   }
2479   if (!type.isa<spirv::PointerType>()) {
2480     return parser.emitError(loc, "expected spv.ptr type");
2481   }
2482   state.addAttribute(kTypeAttrName, TypeAttr::get(type));
2483 
2484   return success();
2485 }
2486 
print(OpAsmPrinter & printer)2487 void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
2488   SmallVector<StringRef, 4> elidedAttrs{
2489       spirv::attributeName<spirv::StorageClass>()};
2490 
2491   // Print variable name.
2492   printer << ' ';
2493   printer.printSymbolName(sym_name());
2494   elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
2495 
2496   // Print optional initializer
2497   if (auto initializer = this->initializer()) {
2498     printer << " " << kInitializerAttrName << '(';
2499     printer.printSymbolName(*initializer);
2500     printer << ')';
2501     elidedAttrs.push_back(kInitializerAttrName);
2502   }
2503 
2504   elidedAttrs.push_back(kTypeAttrName);
2505   printVariableDecorations(*this, printer, elidedAttrs);
2506   printer << " : " << type();
2507 }
2508 
verify()2509 LogicalResult spirv::GlobalVariableOp::verify() {
2510   // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
2511   // object. It cannot be Generic. It must be the same as the Storage Class
2512   // operand of the Result Type."
2513   // Also, Function storage class is reserved by spv.Variable.
2514   auto storageClass = this->storageClass();
2515   if (storageClass == spirv::StorageClass::Generic ||
2516       storageClass == spirv::StorageClass::Function) {
2517     return emitOpError("storage class cannot be '")
2518            << stringifyStorageClass(storageClass) << "'";
2519   }
2520 
2521   if (auto init =
2522           (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
2523     Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
2524         (*this)->getParentOp(), init.getAttr());
2525     // TODO: Currently only variable initialization with specialization
2526     // constants and other variables is supported. They could be normal
2527     // constants in the module scope as well.
2528     if (!initOp ||
2529         !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2530       return emitOpError("initializer must be result of a "
2531                          "spv.SpecConstant or spv.GlobalVariable op");
2532     }
2533   }
2534 
2535   return success();
2536 }
2537 
2538 //===----------------------------------------------------------------------===//
2539 // spv.GroupBroadcast
2540 //===----------------------------------------------------------------------===//
2541 
verify()2542 LogicalResult spirv::GroupBroadcastOp::verify() {
2543   spirv::Scope scope = execution_scope();
2544   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2545     return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2546 
2547   if (auto localIdTy = localid().getType().dyn_cast<VectorType>())
2548     if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
2549       return emitOpError("localid is a vector and can be with only "
2550                          " 2 or 3 components, actual number is ")
2551              << localIdTy.getNumElements();
2552 
2553   return success();
2554 }
2555 
2556 //===----------------------------------------------------------------------===//
2557 // spv.GroupNonUniformBallotOp
2558 //===----------------------------------------------------------------------===//
2559 
verify()2560 LogicalResult spirv::GroupNonUniformBallotOp::verify() {
2561   spirv::Scope scope = execution_scope();
2562   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2563     return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2564 
2565   return success();
2566 }
2567 
2568 //===----------------------------------------------------------------------===//
2569 // spv.GroupNonUniformBroadcast
2570 //===----------------------------------------------------------------------===//
2571 
verify()2572 LogicalResult spirv::GroupNonUniformBroadcastOp::verify() {
2573   spirv::Scope scope = execution_scope();
2574   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2575     return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2576 
2577   // SPIR-V spec: "Before version 1.5, Id must come from a
2578   // constant instruction.
2579   auto targetEnv = spirv::getDefaultTargetEnv(getContext());
2580   if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
2581     targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2582 
2583   if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2584     auto *idOp = id().getDefiningOp();
2585     if (!idOp || !isa<spirv::ConstantOp,           // for normal constant
2586                       spirv::ReferenceOfOp>(idOp)) // for spec constant
2587       return emitOpError("id must be the result of a constant op");
2588   }
2589 
2590   return success();
2591 }
2592 
2593 //===----------------------------------------------------------------------===//
2594 // spv.SubgroupBlockReadINTEL
2595 //===----------------------------------------------------------------------===//
2596 
parse(OpAsmParser & parser,OperationState & state)2597 ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
2598                                                    OperationState &state) {
2599   // Parse the storage class specification
2600   spirv::StorageClass storageClass;
2601   OpAsmParser::UnresolvedOperand ptrInfo;
2602   Type elementType;
2603   if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2604       parser.parseColon() || parser.parseType(elementType)) {
2605     return failure();
2606   }
2607 
2608   auto ptrType = spirv::PointerType::get(elementType, storageClass);
2609   if (auto valVecTy = elementType.dyn_cast<VectorType>())
2610     ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2611 
2612   if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2613     return failure();
2614   }
2615 
2616   state.addTypes(elementType);
2617   return success();
2618 }
2619 
print(OpAsmPrinter & printer)2620 void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) {
2621   printer << " " << ptr() << " : " << getType();
2622 }
2623 
verify()2624 LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
2625   if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
2626     return failure();
2627 
2628   return success();
2629 }
2630 
2631 //===----------------------------------------------------------------------===//
2632 // spv.SubgroupBlockWriteINTEL
2633 //===----------------------------------------------------------------------===//
2634 
parse(OpAsmParser & parser,OperationState & state)2635 ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
2636                                                     OperationState &state) {
2637   // Parse the storage class specification
2638   spirv::StorageClass storageClass;
2639   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
2640   auto loc = parser.getCurrentLocation();
2641   Type elementType;
2642   if (parseEnumStrAttr(storageClass, parser) ||
2643       parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2644       parser.parseType(elementType)) {
2645     return failure();
2646   }
2647 
2648   auto ptrType = spirv::PointerType::get(elementType, storageClass);
2649   if (auto valVecTy = elementType.dyn_cast<VectorType>())
2650     ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2651 
2652   if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2653                              state.operands)) {
2654     return failure();
2655   }
2656   return success();
2657 }
2658 
print(OpAsmPrinter & printer)2659 void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) {
2660   printer << " " << ptr() << ", " << value() << " : " << value().getType();
2661 }
2662 
verify()2663 LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() {
2664   if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
2665     return failure();
2666 
2667   return success();
2668 }
2669 
2670 //===----------------------------------------------------------------------===//
2671 // spv.GroupNonUniformElectOp
2672 //===----------------------------------------------------------------------===//
2673 
verify()2674 LogicalResult spirv::GroupNonUniformElectOp::verify() {
2675   spirv::Scope scope = execution_scope();
2676   if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2677     return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2678 
2679   return success();
2680 }
2681 
2682 //===----------------------------------------------------------------------===//
2683 // spv.GroupNonUniformFAddOp
2684 //===----------------------------------------------------------------------===//
2685 
verify()2686 LogicalResult spirv::GroupNonUniformFAddOp::verify() {
2687   return verifyGroupNonUniformArithmeticOp(*this);
2688 }
2689 
parse(OpAsmParser & parser,OperationState & result)2690 ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser,
2691                                                 OperationState &result) {
2692   return parseGroupNonUniformArithmeticOp(parser, result);
2693 }
print(OpAsmPrinter & p)2694 void spirv::GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
2695   printGroupNonUniformArithmeticOp(*this, p);
2696 }
2697 
2698 //===----------------------------------------------------------------------===//
2699 // spv.GroupNonUniformFMaxOp
2700 //===----------------------------------------------------------------------===//
2701 
verify()2702 LogicalResult spirv::GroupNonUniformFMaxOp::verify() {
2703   return verifyGroupNonUniformArithmeticOp(*this);
2704 }
2705 
parse(OpAsmParser & parser,OperationState & result)2706 ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
2707                                                 OperationState &result) {
2708   return parseGroupNonUniformArithmeticOp(parser, result);
2709 }
print(OpAsmPrinter & p)2710 void spirv::GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
2711   printGroupNonUniformArithmeticOp(*this, p);
2712 }
2713 
2714 //===----------------------------------------------------------------------===//
2715 // spv.GroupNonUniformFMinOp
2716 //===----------------------------------------------------------------------===//
2717 
verify()2718 LogicalResult spirv::GroupNonUniformFMinOp::verify() {
2719   return verifyGroupNonUniformArithmeticOp(*this);
2720 }
2721 
parse(OpAsmParser & parser,OperationState & result)2722 ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser,
2723                                                 OperationState &result) {
2724   return parseGroupNonUniformArithmeticOp(parser, result);
2725 }
print(OpAsmPrinter & p)2726 void spirv::GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
2727   printGroupNonUniformArithmeticOp(*this, p);
2728 }
2729 
2730 //===----------------------------------------------------------------------===//
2731 // spv.GroupNonUniformFMulOp
2732 //===----------------------------------------------------------------------===//
2733 
verify()2734 LogicalResult spirv::GroupNonUniformFMulOp::verify() {
2735   return verifyGroupNonUniformArithmeticOp(*this);
2736 }
2737 
parse(OpAsmParser & parser,OperationState & result)2738 ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser,
2739                                                 OperationState &result) {
2740   return parseGroupNonUniformArithmeticOp(parser, result);
2741 }
print(OpAsmPrinter & p)2742 void spirv::GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
2743   printGroupNonUniformArithmeticOp(*this, p);
2744 }
2745 
2746 //===----------------------------------------------------------------------===//
2747 // spv.GroupNonUniformIAddOp
2748 //===----------------------------------------------------------------------===//
2749 
verify()2750 LogicalResult spirv::GroupNonUniformIAddOp::verify() {
2751   return verifyGroupNonUniformArithmeticOp(*this);
2752 }
2753 
parse(OpAsmParser & parser,OperationState & result)2754 ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser,
2755                                                 OperationState &result) {
2756   return parseGroupNonUniformArithmeticOp(parser, result);
2757 }
print(OpAsmPrinter & p)2758 void spirv::GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
2759   printGroupNonUniformArithmeticOp(*this, p);
2760 }
2761 
2762 //===----------------------------------------------------------------------===//
2763 // spv.GroupNonUniformIMulOp
2764 //===----------------------------------------------------------------------===//
2765 
verify()2766 LogicalResult spirv::GroupNonUniformIMulOp::verify() {
2767   return verifyGroupNonUniformArithmeticOp(*this);
2768 }
2769 
parse(OpAsmParser & parser,OperationState & result)2770 ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser,
2771                                                 OperationState &result) {
2772   return parseGroupNonUniformArithmeticOp(parser, result);
2773 }
print(OpAsmPrinter & p)2774 void spirv::GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
2775   printGroupNonUniformArithmeticOp(*this, p);
2776 }
2777 
2778 //===----------------------------------------------------------------------===//
2779 // spv.GroupNonUniformSMaxOp
2780 //===----------------------------------------------------------------------===//
2781 
verify()2782 LogicalResult spirv::GroupNonUniformSMaxOp::verify() {
2783   return verifyGroupNonUniformArithmeticOp(*this);
2784 }
2785 
parse(OpAsmParser & parser,OperationState & result)2786 ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
2787                                                 OperationState &result) {
2788   return parseGroupNonUniformArithmeticOp(parser, result);
2789 }
print(OpAsmPrinter & p)2790 void spirv::GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
2791   printGroupNonUniformArithmeticOp(*this, p);
2792 }
2793 
2794 //===----------------------------------------------------------------------===//
2795 // spv.GroupNonUniformSMinOp
2796 //===----------------------------------------------------------------------===//
2797 
verify()2798 LogicalResult spirv::GroupNonUniformSMinOp::verify() {
2799   return verifyGroupNonUniformArithmeticOp(*this);
2800 }
2801 
parse(OpAsmParser & parser,OperationState & result)2802 ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser,
2803                                                 OperationState &result) {
2804   return parseGroupNonUniformArithmeticOp(parser, result);
2805 }
print(OpAsmPrinter & p)2806 void spirv::GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
2807   printGroupNonUniformArithmeticOp(*this, p);
2808 }
2809 
2810 //===----------------------------------------------------------------------===//
2811 // spv.GroupNonUniformUMaxOp
2812 //===----------------------------------------------------------------------===//
2813 
verify()2814 LogicalResult spirv::GroupNonUniformUMaxOp::verify() {
2815   return verifyGroupNonUniformArithmeticOp(*this);
2816 }
2817 
parse(OpAsmParser & parser,OperationState & result)2818 ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
2819                                                 OperationState &result) {
2820   return parseGroupNonUniformArithmeticOp(parser, result);
2821 }
print(OpAsmPrinter & p)2822 void spirv::GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
2823   printGroupNonUniformArithmeticOp(*this, p);
2824 }
2825 
2826 //===----------------------------------------------------------------------===//
2827 // spv.GroupNonUniformUMinOp
2828 //===----------------------------------------------------------------------===//
2829 
verify()2830 LogicalResult spirv::GroupNonUniformUMinOp::verify() {
2831   return verifyGroupNonUniformArithmeticOp(*this);
2832 }
2833 
parse(OpAsmParser & parser,OperationState & result)2834 ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser,
2835                                                 OperationState &result) {
2836   return parseGroupNonUniformArithmeticOp(parser, result);
2837 }
print(OpAsmPrinter & p)2838 void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
2839   printGroupNonUniformArithmeticOp(*this, p);
2840 }
2841 
2842 //===----------------------------------------------------------------------===//
2843 // spv.ISubBorrowOp
2844 //===----------------------------------------------------------------------===//
2845 
verify()2846 LogicalResult spirv::ISubBorrowOp::verify() {
2847   auto resultType = getType().cast<spirv::StructType>();
2848   if (resultType.getNumElements() != 2)
2849     return emitOpError("expected result struct type containing two members");
2850 
2851   SmallVector<Type, 4> types;
2852   types.push_back(operand1().getType());
2853   types.push_back(operand2().getType());
2854   types.push_back(resultType.getElementType(0));
2855   types.push_back(resultType.getElementType(1));
2856   if (!llvm::is_splat(types))
2857     return emitOpError(
2858         "expected all operand types and struct member types are the same");
2859 
2860   return success();
2861 }
2862 
parse(OpAsmParser & parser,OperationState & state)2863 ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
2864                                        OperationState &state) {
2865   SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
2866   if (parser.parseOptionalAttrDict(state.attributes) ||
2867       parser.parseOperandList(operands) || parser.parseColon())
2868     return failure();
2869 
2870   Type resultType;
2871   auto loc = parser.getCurrentLocation();
2872   if (parser.parseType(resultType))
2873     return failure();
2874 
2875   auto structType = resultType.dyn_cast<spirv::StructType>();
2876   if (!structType || structType.getNumElements() != 2)
2877     return parser.emitError(loc, "expected spv.struct type with two members");
2878 
2879   SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
2880   if (parser.resolveOperands(operands, operandTypes, loc, state.operands))
2881     return failure();
2882 
2883   state.addTypes(resultType);
2884   return success();
2885 }
2886 
print(OpAsmPrinter & printer)2887 void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
2888   printer << ' ';
2889   printer.printOptionalAttrDict((*this)->getAttrs());
2890   printer.printOperands((*this)->getOperands());
2891   printer << " : " << getType();
2892 }
2893 
2894 //===----------------------------------------------------------------------===//
2895 // spv.LoadOp
2896 //===----------------------------------------------------------------------===//
2897 
build(OpBuilder & builder,OperationState & state,Value basePtr,MemoryAccessAttr memoryAccess,IntegerAttr alignment)2898 void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
2899                           Value basePtr, MemoryAccessAttr memoryAccess,
2900                           IntegerAttr alignment) {
2901   auto ptrType = basePtr.getType().cast<spirv::PointerType>();
2902   build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
2903         alignment);
2904 }
2905 
parse(OpAsmParser & parser,OperationState & state)2906 ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &state) {
2907   // Parse the storage class specification
2908   spirv::StorageClass storageClass;
2909   OpAsmParser::UnresolvedOperand ptrInfo;
2910   Type elementType;
2911   if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2912       parseMemoryAccessAttributes(parser, state) ||
2913       parser.parseOptionalAttrDict(state.attributes) || parser.parseColon() ||
2914       parser.parseType(elementType)) {
2915     return failure();
2916   }
2917 
2918   auto ptrType = spirv::PointerType::get(elementType, storageClass);
2919   if (parser.resolveOperand(ptrInfo, ptrType, state.operands)) {
2920     return failure();
2921   }
2922 
2923   state.addTypes(elementType);
2924   return success();
2925 }
2926 
print(OpAsmPrinter & printer)2927 void spirv::LoadOp::print(OpAsmPrinter &printer) {
2928   SmallVector<StringRef, 4> elidedAttrs;
2929   StringRef sc = stringifyStorageClass(
2930       ptr().getType().cast<spirv::PointerType>().getStorageClass());
2931   printer << " \"" << sc << "\" " << ptr();
2932 
2933   printMemoryAccessAttribute(*this, printer, elidedAttrs);
2934 
2935   printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
2936   printer << " : " << getType();
2937 }
2938 
verify()2939 LogicalResult spirv::LoadOp::verify() {
2940   // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
2941   // type with fixed size; i.e., it cannot be, nor include, any
2942   // OpTypeRuntimeArray types."
2943   if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value()))) {
2944     return failure();
2945   }
2946   return verifyMemoryAccessAttribute(*this);
2947 }
2948 
2949 //===----------------------------------------------------------------------===//
2950 // spv.mlir.loop
2951 //===----------------------------------------------------------------------===//
2952 
build(OpBuilder & builder,OperationState & state)2953 void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
2954   state.addAttribute("loop_control",
2955                      builder.getI32IntegerAttr(
2956                          static_cast<uint32_t>(spirv::LoopControl::None)));
2957   state.addRegion();
2958 }
2959 
parse(OpAsmParser & parser,OperationState & state)2960 ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &state) {
2961   if (parseControlAttribute<spirv::LoopControl>(parser, state))
2962     return failure();
2963   return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
2964                             /*argTypes=*/{});
2965 }
2966 
print(OpAsmPrinter & printer)2967 void spirv::LoopOp::print(OpAsmPrinter &printer) {
2968   auto control = loop_control();
2969   if (control != spirv::LoopControl::None)
2970     printer << " control(" << spirv::stringifyLoopControl(control) << ")";
2971   printer << ' ';
2972   printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
2973                       /*printBlockTerminators=*/true);
2974 }
2975 
2976 /// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
2977 /// given `dstBlock`.
hasOneBranchOpTo(Block & srcBlock,Block & dstBlock)2978 static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
2979   // Check that there is only one op in the `srcBlock`.
2980   if (!llvm::hasSingleElement(srcBlock))
2981     return false;
2982 
2983   auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
2984   return branchOp && branchOp.getSuccessor() == &dstBlock;
2985 }
2986 
verifyRegions()2987 LogicalResult spirv::LoopOp::verifyRegions() {
2988   auto *op = getOperation();
2989 
2990   // We need to verify that the blocks follow the following layout:
2991   //
2992   //                     +-------------+
2993   //                     | entry block |
2994   //                     +-------------+
2995   //                            |
2996   //                            v
2997   //                     +-------------+
2998   //                     | loop header | <-----+
2999   //                     +-------------+       |
3000   //                                           |
3001   //                           ...             |
3002   //                          \ | /            |
3003   //                            v              |
3004   //                    +---------------+      |
3005   //                    | loop continue | -----+
3006   //                    +---------------+
3007   //
3008   //                           ...
3009   //                          \ | /
3010   //                            v
3011   //                     +-------------+
3012   //                     | merge block |
3013   //                     +-------------+
3014 
3015   auto &region = op->getRegion(0);
3016   // Allow empty region as a degenerated case, which can come from
3017   // optimizations.
3018   if (region.empty())
3019     return success();
3020 
3021   // The last block is the merge block.
3022   Block &merge = region.back();
3023   if (!isMergeBlock(merge))
3024     return emitOpError(
3025         "last block must be the merge block with only one 'spv.mlir.merge' op");
3026 
3027   if (std::next(region.begin()) == region.end())
3028     return emitOpError(
3029         "must have an entry block branching to the loop header block");
3030   // The first block is the entry block.
3031   Block &entry = region.front();
3032 
3033   if (std::next(region.begin(), 2) == region.end())
3034     return emitOpError(
3035         "must have a loop header block branched from the entry block");
3036   // The second block is the loop header block.
3037   Block &header = *std::next(region.begin(), 1);
3038 
3039   if (!hasOneBranchOpTo(entry, header))
3040     return emitOpError(
3041         "entry block must only have one 'spv.Branch' op to the second block");
3042 
3043   if (std::next(region.begin(), 3) == region.end())
3044     return emitOpError(
3045         "requires a loop continue block branching to the loop header block");
3046   // The second to last block is the loop continue block.
3047   Block &cont = *std::prev(region.end(), 2);
3048 
3049   // Make sure that we have a branch from the loop continue block to the loop
3050   // header block.
3051   if (llvm::none_of(
3052           llvm::seq<unsigned>(0, cont.getNumSuccessors()),
3053           [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
3054     return emitOpError("second to last block must be the loop continue "
3055                        "block that branches to the loop header block");
3056 
3057   // Make sure that no other blocks (except the entry and loop continue block)
3058   // branches to the loop header block.
3059   for (auto &block : llvm::make_range(std::next(region.begin(), 2),
3060                                       std::prev(region.end(), 2))) {
3061     for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
3062       if (block.getSuccessor(i) == &header) {
3063         return emitOpError("can only have the entry and loop continue "
3064                            "block branching to the loop header block");
3065       }
3066     }
3067   }
3068 
3069   return success();
3070 }
3071 
getEntryBlock()3072 Block *spirv::LoopOp::getEntryBlock() {
3073   assert(!body().empty() && "op region should not be empty!");
3074   return &body().front();
3075 }
3076 
getHeaderBlock()3077 Block *spirv::LoopOp::getHeaderBlock() {
3078   assert(!body().empty() && "op region should not be empty!");
3079   // The second block is the loop header block.
3080   return &*std::next(body().begin());
3081 }
3082 
getContinueBlock()3083 Block *spirv::LoopOp::getContinueBlock() {
3084   assert(!body().empty() && "op region should not be empty!");
3085   // The second to last block is the loop continue block.
3086   return &*std::prev(body().end(), 2);
3087 }
3088 
getMergeBlock()3089 Block *spirv::LoopOp::getMergeBlock() {
3090   assert(!body().empty() && "op region should not be empty!");
3091   // The last block is the loop merge block.
3092   return &body().back();
3093 }
3094 
addEntryAndMergeBlock()3095 void spirv::LoopOp::addEntryAndMergeBlock() {
3096   assert(body().empty() && "entry and merge block already exist");
3097   body().push_back(new Block());
3098   auto *mergeBlock = new Block();
3099   body().push_back(mergeBlock);
3100   OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3101 
3102   // Add a spv.mlir.merge op into the merge block.
3103   builder.create<spirv::MergeOp>(getLoc());
3104 }
3105 
3106 //===----------------------------------------------------------------------===//
3107 // spv.MemoryBarrierOp
3108 //===----------------------------------------------------------------------===//
3109 
verify()3110 LogicalResult spirv::MemoryBarrierOp::verify() {
3111   return verifyMemorySemantics(getOperation(), memory_semantics());
3112 }
3113 
3114 //===----------------------------------------------------------------------===//
3115 // spv.mlir.merge
3116 //===----------------------------------------------------------------------===//
3117 
verify()3118 LogicalResult spirv::MergeOp::verify() {
3119   auto *parentOp = (*this)->getParentOp();
3120   if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
3121     return emitOpError(
3122         "expected parent op to be 'spv.mlir.selection' or 'spv.mlir.loop'");
3123 
3124   // TODO: This check should be done in `verifyRegions` of parent op.
3125   Block &parentLastBlock = (*this)->getParentRegion()->back();
3126   if (getOperation() != parentLastBlock.getTerminator())
3127     return emitOpError("can only be used in the last block of "
3128                        "'spv.mlir.selection' or 'spv.mlir.loop'");
3129   return success();
3130 }
3131 
3132 //===----------------------------------------------------------------------===//
3133 // spv.module
3134 //===----------------------------------------------------------------------===//
3135 
build(OpBuilder & builder,OperationState & state,Optional<StringRef> name)3136 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3137                             Optional<StringRef> name) {
3138   OpBuilder::InsertionGuard guard(builder);
3139   builder.createBlock(state.addRegion());
3140   if (name) {
3141     state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
3142                             builder.getStringAttr(*name));
3143   }
3144 }
3145 
build(OpBuilder & builder,OperationState & state,spirv::AddressingModel addressingModel,spirv::MemoryModel memoryModel,Optional<VerCapExtAttr> vceTriple,Optional<StringRef> name)3146 void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3147                             spirv::AddressingModel addressingModel,
3148                             spirv::MemoryModel memoryModel,
3149                             Optional<VerCapExtAttr> vceTriple,
3150                             Optional<StringRef> name) {
3151   state.addAttribute(
3152       "addressing_model",
3153       builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
3154   state.addAttribute("memory_model", builder.getI32IntegerAttr(
3155                                          static_cast<int32_t>(memoryModel)));
3156   OpBuilder::InsertionGuard guard(builder);
3157   builder.createBlock(state.addRegion());
3158   if (vceTriple)
3159     state.addAttribute(getVCETripleAttrName(), *vceTriple);
3160   if (name)
3161     state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
3162                        builder.getStringAttr(*name));
3163 }
3164 
parse(OpAsmParser & parser,OperationState & state)3165 ParseResult spirv::ModuleOp::parse(OpAsmParser &parser, OperationState &state) {
3166   Region *body = state.addRegion();
3167 
3168   // If the name is present, parse it.
3169   StringAttr nameAttr;
3170   (void)parser.parseOptionalSymbolName(
3171       nameAttr, mlir::SymbolTable::getSymbolAttrName(), state.attributes);
3172 
3173   // Parse attributes
3174   spirv::AddressingModel addrModel;
3175   spirv::MemoryModel memoryModel;
3176   if (::parseEnumKeywordAttr(addrModel, parser, state) ||
3177       ::parseEnumKeywordAttr(memoryModel, parser, state))
3178     return failure();
3179 
3180   if (succeeded(parser.parseOptionalKeyword("requires"))) {
3181     spirv::VerCapExtAttr vceTriple;
3182     if (parser.parseAttribute(vceTriple,
3183                               spirv::ModuleOp::getVCETripleAttrName(),
3184                               state.attributes))
3185       return failure();
3186   }
3187 
3188   if (parser.parseOptionalAttrDictWithKeyword(state.attributes) ||
3189       parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
3190     return failure();
3191 
3192   // Make sure we have at least one block.
3193   if (body->empty())
3194     body->push_back(new Block());
3195 
3196   return success();
3197 }
3198 
print(OpAsmPrinter & printer)3199 void spirv::ModuleOp::print(OpAsmPrinter &printer) {
3200   if (Optional<StringRef> name = getName()) {
3201     printer << ' ';
3202     printer.printSymbolName(*name);
3203   }
3204 
3205   SmallVector<StringRef, 2> elidedAttrs;
3206 
3207   printer << " " << spirv::stringifyAddressingModel(addressing_model()) << " "
3208           << spirv::stringifyMemoryModel(memory_model());
3209   auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
3210   auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
3211   elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
3212                       mlir::SymbolTable::getSymbolAttrName()});
3213 
3214   if (Optional<spirv::VerCapExtAttr> triple = vce_triple()) {
3215     printer << " requires " << *triple;
3216     elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
3217   }
3218 
3219   printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
3220   printer << ' ';
3221   printer.printRegion(getRegion());
3222 }
3223 
verifyRegions()3224 LogicalResult spirv::ModuleOp::verifyRegions() {
3225   Dialect *dialect = (*this)->getDialect();
3226   DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
3227       entryPoints;
3228   mlir::SymbolTable table(*this);
3229 
3230   for (auto &op : *getBody()) {
3231     if (op.getDialect() != dialect)
3232       return op.emitError("'spv.module' can only contain spv.* ops");
3233 
3234     // For EntryPoint op, check that the function and execution model is not
3235     // duplicated in EntryPointOps. Also verify that the interface specified
3236     // comes from globalVariables here to make this check cheaper.
3237     if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
3238       auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.fn());
3239       if (!funcOp) {
3240         return entryPointOp.emitError("function '")
3241                << entryPointOp.fn() << "' not found in 'spv.module'";
3242       }
3243       if (auto interface = entryPointOp.interface()) {
3244         for (Attribute varRef : interface) {
3245           auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
3246           if (!varSymRef) {
3247             return entryPointOp.emitError(
3248                        "expected symbol reference for interface "
3249                        "specification instead of '")
3250                    << varRef;
3251           }
3252           auto variableOp =
3253               table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
3254           if (!variableOp) {
3255             return entryPointOp.emitError("expected spv.GlobalVariable "
3256                                           "symbol reference instead of'")
3257                    << varSymRef << "'";
3258           }
3259         }
3260       }
3261 
3262       auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
3263           funcOp, entryPointOp.execution_model());
3264       auto entryPtIt = entryPoints.find(key);
3265       if (entryPtIt != entryPoints.end()) {
3266         return entryPointOp.emitError("duplicate of a previous EntryPointOp");
3267       }
3268       entryPoints[key] = entryPointOp;
3269     } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
3270       if (funcOp.isExternal())
3271         return op.emitError("'spv.module' cannot contain external functions");
3272 
3273       // TODO: move this check to spv.func.
3274       for (auto &block : funcOp)
3275         for (auto &op : block) {
3276           if (op.getDialect() != dialect)
3277             return op.emitError(
3278                 "functions in 'spv.module' can only contain spv.* ops");
3279         }
3280     }
3281   }
3282 
3283   return success();
3284 }
3285 
3286 //===----------------------------------------------------------------------===//
3287 // spv.mlir.referenceof
3288 //===----------------------------------------------------------------------===//
3289 
verify()3290 LogicalResult spirv::ReferenceOfOp::verify() {
3291   auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
3292       (*this)->getParentOp(), spec_constAttr());
3293   Type constType;
3294 
3295   auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
3296   if (specConstOp)
3297     constType = specConstOp.default_value().getType();
3298 
3299   auto specConstCompositeOp =
3300       dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
3301   if (specConstCompositeOp)
3302     constType = specConstCompositeOp.type();
3303 
3304   if (!specConstOp && !specConstCompositeOp)
3305     return emitOpError(
3306         "expected spv.SpecConstant or spv.SpecConstantComposite symbol");
3307 
3308   if (reference().getType() != constType)
3309     return emitOpError("result type mismatch with the referenced "
3310                        "specialization constant's type");
3311 
3312   return success();
3313 }
3314 
3315 //===----------------------------------------------------------------------===//
3316 // spv.Return
3317 //===----------------------------------------------------------------------===//
3318 
verify()3319 LogicalResult spirv::ReturnOp::verify() {
3320   // Verification is performed in spv.func op.
3321   return success();
3322 }
3323 
3324 //===----------------------------------------------------------------------===//
3325 // spv.ReturnValue
3326 //===----------------------------------------------------------------------===//
3327 
verify()3328 LogicalResult spirv::ReturnValueOp::verify() {
3329   // Verification is performed in spv.func op.
3330   return success();
3331 }
3332 
3333 //===----------------------------------------------------------------------===//
3334 // spv.Select
3335 //===----------------------------------------------------------------------===//
3336 
verify()3337 LogicalResult spirv::SelectOp::verify() {
3338   if (auto conditionTy = condition().getType().dyn_cast<VectorType>()) {
3339     auto resultVectorTy = result().getType().dyn_cast<VectorType>();
3340     if (!resultVectorTy) {
3341       return emitOpError("result expected to be of vector type when "
3342                          "condition is of vector type");
3343     }
3344     if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
3345       return emitOpError("result should have the same number of elements as "
3346                          "the condition when condition is of vector type");
3347     }
3348   }
3349   return success();
3350 }
3351 
3352 //===----------------------------------------------------------------------===//
3353 // spv.mlir.selection
3354 //===----------------------------------------------------------------------===//
3355 
parse(OpAsmParser & parser,OperationState & state)3356 ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
3357                                       OperationState &state) {
3358   if (parseControlAttribute<spirv::SelectionControl>(parser, state))
3359     return failure();
3360   return parser.parseRegion(*state.addRegion(), /*arguments=*/{},
3361                             /*argTypes=*/{});
3362 }
3363 
print(OpAsmPrinter & printer)3364 void spirv::SelectionOp::print(OpAsmPrinter &printer) {
3365   auto control = selection_control();
3366   if (control != spirv::SelectionControl::None)
3367     printer << " control(" << spirv::stringifySelectionControl(control) << ")";
3368   printer << ' ';
3369   printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3370                       /*printBlockTerminators=*/true);
3371 }
3372 
verifyRegions()3373 LogicalResult spirv::SelectionOp::verifyRegions() {
3374   auto *op = getOperation();
3375 
3376   // We need to verify that the blocks follow the following layout:
3377   //
3378   //                     +--------------+
3379   //                     | header block |
3380   //                     +--------------+
3381   //                          / | \
3382   //                           ...
3383   //
3384   //
3385   //         +---------+   +---------+   +---------+
3386   //         | case #0 |   | case #1 |   | case #2 |  ...
3387   //         +---------+   +---------+   +---------+
3388   //
3389   //
3390   //                           ...
3391   //                          \ | /
3392   //                            v
3393   //                     +-------------+
3394   //                     | merge block |
3395   //                     +-------------+
3396 
3397   auto &region = op->getRegion(0);
3398   // Allow empty region as a degenerated case, which can come from
3399   // optimizations.
3400   if (region.empty())
3401     return success();
3402 
3403   // The last block is the merge block.
3404   if (!isMergeBlock(region.back()))
3405     return emitOpError(
3406         "last block must be the merge block with only one 'spv.mlir.merge' op");
3407 
3408   if (std::next(region.begin()) == region.end())
3409     return emitOpError("must have a selection header block");
3410 
3411   return success();
3412 }
3413 
getHeaderBlock()3414 Block *spirv::SelectionOp::getHeaderBlock() {
3415   assert(!body().empty() && "op region should not be empty!");
3416   // The first block is the loop header block.
3417   return &body().front();
3418 }
3419 
getMergeBlock()3420 Block *spirv::SelectionOp::getMergeBlock() {
3421   assert(!body().empty() && "op region should not be empty!");
3422   // The last block is the loop merge block.
3423   return &body().back();
3424 }
3425 
addMergeBlock()3426 void spirv::SelectionOp::addMergeBlock() {
3427   assert(body().empty() && "entry and merge block already exist");
3428   auto *mergeBlock = new Block();
3429   body().push_back(mergeBlock);
3430   OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3431 
3432   // Add a spv.mlir.merge op into the merge block.
3433   builder.create<spirv::MergeOp>(getLoc());
3434 }
3435 
createIfThen(Location loc,Value condition,function_ref<void (OpBuilder & builder)> thenBody,OpBuilder & builder)3436 spirv::SelectionOp spirv::SelectionOp::createIfThen(
3437     Location loc, Value condition,
3438     function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
3439   auto selectionOp =
3440       builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
3441 
3442   selectionOp.addMergeBlock();
3443   Block *mergeBlock = selectionOp.getMergeBlock();
3444   Block *thenBlock = nullptr;
3445 
3446   // Build the "then" block.
3447   {
3448     OpBuilder::InsertionGuard guard(builder);
3449     thenBlock = builder.createBlock(mergeBlock);
3450     thenBody(builder);
3451     builder.create<spirv::BranchOp>(loc, mergeBlock);
3452   }
3453 
3454   // Build the header block.
3455   {
3456     OpBuilder::InsertionGuard guard(builder);
3457     builder.createBlock(thenBlock);
3458     builder.create<spirv::BranchConditionalOp>(
3459         loc, condition, thenBlock,
3460         /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
3461         /*falseArguments=*/ArrayRef<Value>());
3462   }
3463 
3464   return selectionOp;
3465 }
3466 
3467 //===----------------------------------------------------------------------===//
3468 // spv.SpecConstant
3469 //===----------------------------------------------------------------------===//
3470 
parse(OpAsmParser & parser,OperationState & state)3471 ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
3472                                          OperationState &state) {
3473   StringAttr nameAttr;
3474   Attribute valueAttr;
3475 
3476   if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
3477                              state.attributes))
3478     return failure();
3479 
3480   // Parse optional spec_id.
3481   if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
3482     IntegerAttr specIdAttr;
3483     if (parser.parseLParen() ||
3484         parser.parseAttribute(specIdAttr, kSpecIdAttrName, state.attributes) ||
3485         parser.parseRParen())
3486       return failure();
3487   }
3488 
3489   if (parser.parseEqual() ||
3490       parser.parseAttribute(valueAttr, kDefaultValueAttrName, state.attributes))
3491     return failure();
3492 
3493   return success();
3494 }
3495 
print(OpAsmPrinter & printer)3496 void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
3497   printer << ' ';
3498   printer.printSymbolName(sym_name());
3499   if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3500     printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
3501   printer << " = " << default_value();
3502 }
3503 
verify()3504 LogicalResult spirv::SpecConstantOp::verify() {
3505   if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3506     if (specID.getValue().isNegative())
3507       return emitOpError("SpecId cannot be negative");
3508 
3509   auto value = default_value();
3510   if (value.isa<IntegerAttr, FloatAttr>()) {
3511     // Make sure bitwidth is allowed.
3512     if (!value.getType().isa<spirv::SPIRVType>())
3513       return emitOpError("default value bitwidth disallowed");
3514     return success();
3515   }
3516   return emitOpError(
3517       "default value can only be a bool, integer, or float scalar");
3518 }
3519 
3520 //===----------------------------------------------------------------------===//
3521 // spv.StoreOp
3522 //===----------------------------------------------------------------------===//
3523 
parse(OpAsmParser & parser,OperationState & state)3524 ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &state) {
3525   // Parse the storage class specification
3526   spirv::StorageClass storageClass;
3527   SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
3528   auto loc = parser.getCurrentLocation();
3529   Type elementType;
3530   if (parseEnumStrAttr(storageClass, parser) ||
3531       parser.parseOperandList(operandInfo, 2) ||
3532       parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3533       parser.parseType(elementType)) {
3534     return failure();
3535   }
3536 
3537   auto ptrType = spirv::PointerType::get(elementType, storageClass);
3538   if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
3539                              state.operands)) {
3540     return failure();
3541   }
3542   return success();
3543 }
3544 
print(OpAsmPrinter & printer)3545 void spirv::StoreOp::print(OpAsmPrinter &printer) {
3546   SmallVector<StringRef, 4> elidedAttrs;
3547   StringRef sc = stringifyStorageClass(
3548       ptr().getType().cast<spirv::PointerType>().getStorageClass());
3549   printer << " \"" << sc << "\" " << ptr() << ", " << value();
3550 
3551   printMemoryAccessAttribute(*this, printer, elidedAttrs);
3552 
3553   printer << " : " << value().getType();
3554   printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3555 }
3556 
verify()3557 LogicalResult spirv::StoreOp::verify() {
3558   // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
3559   // OpTypePointer whose Type operand is the same as the type of Object."
3560   if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value())))
3561     return failure();
3562   return verifyMemoryAccessAttribute(*this);
3563 }
3564 
3565 //===----------------------------------------------------------------------===//
3566 // spv.Unreachable
3567 //===----------------------------------------------------------------------===//
3568 
verify()3569 LogicalResult spirv::UnreachableOp::verify() {
3570   auto *block = (*this)->getBlock();
3571   // Fast track: if this is in entry block, its invalid. Otherwise, if no
3572   // predecessors, it's valid.
3573   if (block->isEntryBlock())
3574     return emitOpError("cannot be used in reachable block");
3575   if (block->hasNoPredecessors())
3576     return success();
3577 
3578   // TODO: further verification needs to analyze reachability from
3579   // the entry block.
3580 
3581   return success();
3582 }
3583 
3584 //===----------------------------------------------------------------------===//
3585 // spv.Variable
3586 //===----------------------------------------------------------------------===//
3587 
parse(OpAsmParser & parser,OperationState & state)3588 ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
3589                                      OperationState &state) {
3590   // Parse optional initializer
3591   Optional<OpAsmParser::UnresolvedOperand> initInfo;
3592   if (succeeded(parser.parseOptionalKeyword("init"))) {
3593     initInfo = OpAsmParser::UnresolvedOperand();
3594     if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
3595         parser.parseRParen())
3596       return failure();
3597   }
3598 
3599   if (parseVariableDecorations(parser, state)) {
3600     return failure();
3601   }
3602 
3603   // Parse result pointer type
3604   Type type;
3605   if (parser.parseColon())
3606     return failure();
3607   auto loc = parser.getCurrentLocation();
3608   if (parser.parseType(type))
3609     return failure();
3610 
3611   auto ptrType = type.dyn_cast<spirv::PointerType>();
3612   if (!ptrType)
3613     return parser.emitError(loc, "expected spv.ptr type");
3614   state.addTypes(ptrType);
3615 
3616   // Resolve the initializer operand
3617   if (initInfo) {
3618     if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
3619                               state.operands))
3620       return failure();
3621   }
3622 
3623   auto attr = parser.getBuilder().getI32IntegerAttr(
3624       llvm::bit_cast<int32_t>(ptrType.getStorageClass()));
3625   state.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
3626 
3627   return success();
3628 }
3629 
print(OpAsmPrinter & printer)3630 void spirv::VariableOp::print(OpAsmPrinter &printer) {
3631   SmallVector<StringRef, 4> elidedAttrs{
3632       spirv::attributeName<spirv::StorageClass>()};
3633   // Print optional initializer
3634   if (getNumOperands() != 0)
3635     printer << " init(" << initializer() << ")";
3636 
3637   printVariableDecorations(*this, printer, elidedAttrs);
3638   printer << " : " << getType();
3639 }
3640 
verify()3641 LogicalResult spirv::VariableOp::verify() {
3642   // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
3643   // object. It cannot be Generic. It must be the same as the Storage Class
3644   // operand of the Result Type."
3645   if (storage_class() != spirv::StorageClass::Function) {
3646     return emitOpError(
3647         "can only be used to model function-level variables. Use "
3648         "spv.GlobalVariable for module-level variables.");
3649   }
3650 
3651   auto pointerType = pointer().getType().cast<spirv::PointerType>();
3652   if (storage_class() != pointerType.getStorageClass())
3653     return emitOpError(
3654         "storage class must match result pointer's storage class");
3655 
3656   if (getNumOperands() != 0) {
3657     // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
3658     // a global (module scope) OpVariable instruction".
3659     auto *initOp = getOperand(0).getDefiningOp();
3660     if (!initOp || !isa<spirv::ConstantOp,    // for normal constant
3661                         spirv::ReferenceOfOp, // for spec constant
3662                         spirv::AddressOfOp>(initOp))
3663       return emitOpError("initializer must be the result of a "
3664                          "constant or spv.GlobalVariable op");
3665   }
3666 
3667   // TODO: generate these strings using ODS.
3668   auto *op = getOperation();
3669   auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
3670       stringifyDecoration(spirv::Decoration::DescriptorSet));
3671   auto bindingName = llvm::convertToSnakeFromCamelCase(
3672       stringifyDecoration(spirv::Decoration::Binding));
3673   auto builtInName = llvm::convertToSnakeFromCamelCase(
3674       stringifyDecoration(spirv::Decoration::BuiltIn));
3675 
3676   for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
3677     if (op->getAttr(attr))
3678       return emitOpError("cannot have '")
3679              << attr << "' attribute (only allowed in spv.GlobalVariable)";
3680   }
3681 
3682   return success();
3683 }
3684 
3685 //===----------------------------------------------------------------------===//
3686 // spv.VectorShuffle
3687 //===----------------------------------------------------------------------===//
3688 
verify()3689 LogicalResult spirv::VectorShuffleOp::verify() {
3690   VectorType resultType = getType().cast<VectorType>();
3691 
3692   size_t numResultElements = resultType.getNumElements();
3693   if (numResultElements != components().size())
3694     return emitOpError("result type element count (")
3695            << numResultElements
3696            << ") mismatch with the number of component selectors ("
3697            << components().size() << ")";
3698 
3699   size_t totalSrcElements =
3700       vector1().getType().cast<VectorType>().getNumElements() +
3701       vector2().getType().cast<VectorType>().getNumElements();
3702 
3703   for (const auto &selector : components().getAsValueRange<IntegerAttr>()) {
3704     uint32_t index = selector.getZExtValue();
3705     if (index >= totalSrcElements &&
3706         index != std::numeric_limits<uint32_t>().max())
3707       return emitOpError("component selector ")
3708              << index << " out of range: expected to be in [0, "
3709              << totalSrcElements << ") or 0xffffffff";
3710   }
3711   return success();
3712 }
3713 
3714 //===----------------------------------------------------------------------===//
3715 // spv.CooperativeMatrixLoadNV
3716 //===----------------------------------------------------------------------===//
3717 
parse(OpAsmParser & parser,OperationState & state)3718 ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
3719                                                     OperationState &state) {
3720   SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
3721   Type strideType = parser.getBuilder().getIntegerType(32);
3722   Type columnMajorType = parser.getBuilder().getIntegerType(1);
3723   Type ptrType;
3724   Type elementType;
3725   if (parser.parseOperandList(operandInfo, 3) ||
3726       parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3727       parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
3728     return failure();
3729   }
3730   if (parser.resolveOperands(operandInfo,
3731                              {ptrType, strideType, columnMajorType},
3732                              parser.getNameLoc(), state.operands)) {
3733     return failure();
3734   }
3735 
3736   state.addTypes(elementType);
3737   return success();
3738 }
3739 
print(OpAsmPrinter & printer)3740 void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) {
3741   printer << " " << pointer() << ", " << stride() << ", " << columnmajor();
3742   // Print optional memory access attribute.
3743   if (auto memAccess = memory_access())
3744     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3745   printer << " : " << pointer().getType() << " as " << getType();
3746 }
3747 
verifyPointerAndCoopMatrixType(Operation * op,Type pointer,Type coopMatrix)3748 static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
3749                                                     Type coopMatrix) {
3750   Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
3751   if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
3752     return op->emitError(
3753                "Pointer must point to a scalar or vector type but provided ")
3754            << pointeeType;
3755   spirv::StorageClass storage =
3756       pointer.cast<spirv::PointerType>().getStorageClass();
3757   if (storage != spirv::StorageClass::Workgroup &&
3758       storage != spirv::StorageClass::StorageBuffer &&
3759       storage != spirv::StorageClass::PhysicalStorageBuffer)
3760     return op->emitError(
3761                "Pointer storage class must be Workgroup, StorageBuffer or "
3762                "PhysicalStorageBufferEXT but provided ")
3763            << stringifyStorageClass(storage);
3764   return success();
3765 }
3766 
verify()3767 LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
3768   return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
3769                                         result().getType());
3770 }
3771 
3772 //===----------------------------------------------------------------------===//
3773 // spv.CooperativeMatrixStoreNV
3774 //===----------------------------------------------------------------------===//
3775 
parse(OpAsmParser & parser,OperationState & state)3776 ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
3777                                                      OperationState &state) {
3778   SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
3779   Type strideType = parser.getBuilder().getIntegerType(32);
3780   Type columnMajorType = parser.getBuilder().getIntegerType(1);
3781   Type ptrType;
3782   Type elementType;
3783   if (parser.parseOperandList(operandInfo, 4) ||
3784       parseMemoryAccessAttributes(parser, state) || parser.parseColon() ||
3785       parser.parseType(ptrType) || parser.parseComma() ||
3786       parser.parseType(elementType)) {
3787     return failure();
3788   }
3789   if (parser.resolveOperands(
3790           operandInfo, {ptrType, elementType, strideType, columnMajorType},
3791           parser.getNameLoc(), state.operands)) {
3792     return failure();
3793   }
3794 
3795   return success();
3796 }
3797 
print(OpAsmPrinter & printer)3798 void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
3799   printer << " " << pointer() << ", " << object() << ", " << stride() << ", "
3800           << columnmajor();
3801   // Print optional memory access attribute.
3802   if (auto memAccess = memory_access())
3803     printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3804   printer << " : " << pointer().getType() << ", " << getOperand(1).getType();
3805 }
3806 
verify()3807 LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
3808   return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
3809                                         object().getType());
3810 }
3811 
3812 //===----------------------------------------------------------------------===//
3813 // spv.CooperativeMatrixMulAddNV
3814 //===----------------------------------------------------------------------===//
3815 
3816 static LogicalResult
verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op)3817 verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
3818   if (op.c().getType() != op.result().getType())
3819     return op.emitOpError("result and third operand must have the same type");
3820   auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
3821   auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
3822   auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
3823   auto typeR = op.result().getType().cast<spirv::CooperativeMatrixNVType>();
3824   if (typeA.getRows() != typeR.getRows() ||
3825       typeA.getColumns() != typeB.getRows() ||
3826       typeB.getColumns() != typeR.getColumns())
3827     return op.emitOpError("matrix size must match");
3828   if (typeR.getScope() != typeA.getScope() ||
3829       typeR.getScope() != typeB.getScope() ||
3830       typeR.getScope() != typeC.getScope())
3831     return op.emitOpError("matrix scope must match");
3832   if (typeA.getElementType() != typeB.getElementType() ||
3833       typeR.getElementType() != typeC.getElementType())
3834     return op.emitOpError("matrix element type must match");
3835   return success();
3836 }
3837 
verify()3838 LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() {
3839   return verifyCoopMatrixMulAdd(*this);
3840 }
3841 
3842 //===----------------------------------------------------------------------===//
3843 // spv.MatrixTimesScalar
3844 //===----------------------------------------------------------------------===//
3845 
verify()3846 LogicalResult spirv::MatrixTimesScalarOp::verify() {
3847   // We already checked that result and matrix are both of matrix type in the
3848   // auto-generated verify method.
3849 
3850   auto inputMatrix = matrix().getType().cast<spirv::MatrixType>();
3851   auto resultMatrix = result().getType().cast<spirv::MatrixType>();
3852 
3853   // Check that the scalar type is the same as the matrix element type.
3854   if (scalar().getType() != inputMatrix.getElementType())
3855     return emitError("input matrix components' type and scaling value must "
3856                      "have the same type");
3857 
3858   // Note that the next three checks could be done using the AllTypesMatch
3859   // trait in the Op definition file but it generates a vague error message.
3860 
3861   // Check that the input and result matrices have the same columns' count
3862   if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
3863     return emitError("input and result matrices must have the same "
3864                      "number of columns");
3865 
3866   // Check that the input and result matrices' have the same rows count
3867   if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
3868     return emitError("input and result matrices' columns must have "
3869                      "the same size");
3870 
3871   // Check that the input and result matrices' have the same component type
3872   if (inputMatrix.getElementType() != resultMatrix.getElementType())
3873     return emitError("input and result matrices' columns must have "
3874                      "the same component type");
3875 
3876   return success();
3877 }
3878 
3879 //===----------------------------------------------------------------------===//
3880 // spv.CopyMemory
3881 //===----------------------------------------------------------------------===//
3882 
print(OpAsmPrinter & printer)3883 void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) {
3884   printer << ' ';
3885 
3886   StringRef targetStorageClass = stringifyStorageClass(
3887       target().getType().cast<spirv::PointerType>().getStorageClass());
3888   printer << " \"" << targetStorageClass << "\" " << target() << ", ";
3889 
3890   StringRef sourceStorageClass = stringifyStorageClass(
3891       source().getType().cast<spirv::PointerType>().getStorageClass());
3892   printer << " \"" << sourceStorageClass << "\" " << source();
3893 
3894   SmallVector<StringRef, 4> elidedAttrs;
3895   printMemoryAccessAttribute(*this, printer, elidedAttrs);
3896   printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
3897                                    source_memory_access(), source_alignment());
3898 
3899   printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3900 
3901   Type pointeeType =
3902       target().getType().cast<spirv::PointerType>().getPointeeType();
3903   printer << " : " << pointeeType;
3904 }
3905 
parse(OpAsmParser & parser,OperationState & state)3906 ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser,
3907                                        OperationState &state) {
3908   spirv::StorageClass targetStorageClass;
3909   OpAsmParser::UnresolvedOperand targetPtrInfo;
3910 
3911   spirv::StorageClass sourceStorageClass;
3912   OpAsmParser::UnresolvedOperand sourcePtrInfo;
3913 
3914   Type elementType;
3915 
3916   if (parseEnumStrAttr(targetStorageClass, parser) ||
3917       parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
3918       parseEnumStrAttr(sourceStorageClass, parser) ||
3919       parser.parseOperand(sourcePtrInfo) ||
3920       parseMemoryAccessAttributes(parser, state)) {
3921     return failure();
3922   }
3923 
3924   if (!parser.parseOptionalComma()) {
3925     // Parse 2nd memory access attributes.
3926     if (parseSourceMemoryAccessAttributes(parser, state)) {
3927       return failure();
3928     }
3929   }
3930 
3931   if (parser.parseColon() || parser.parseType(elementType))
3932     return failure();
3933 
3934   if (parser.parseOptionalAttrDict(state.attributes))
3935     return failure();
3936 
3937   auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
3938   auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
3939 
3940   if (parser.resolveOperand(targetPtrInfo, targetPtrType, state.operands) ||
3941       parser.resolveOperand(sourcePtrInfo, sourcePtrType, state.operands)) {
3942     return failure();
3943   }
3944 
3945   return success();
3946 }
3947 
verify()3948 LogicalResult spirv::CopyMemoryOp::verify() {
3949   Type targetType =
3950       target().getType().cast<spirv::PointerType>().getPointeeType();
3951 
3952   Type sourceType =
3953       source().getType().cast<spirv::PointerType>().getPointeeType();
3954 
3955   if (targetType != sourceType)
3956     return emitOpError("both operands must be pointers to the same type");
3957 
3958   if (failed(verifyMemoryAccessAttribute(*this)))
3959     return failure();
3960 
3961   // TODO - According to the spec:
3962   //
3963   // If two masks are present, the first applies to Target and cannot include
3964   // MakePointerVisible, and the second applies to Source and cannot include
3965   // MakePointerAvailable.
3966   //
3967   // Add such verification here.
3968 
3969   return verifySourceMemoryAccessAttribute(*this);
3970 }
3971 
3972 //===----------------------------------------------------------------------===//
3973 // spv.Transpose
3974 //===----------------------------------------------------------------------===//
3975 
verify()3976 LogicalResult spirv::TransposeOp::verify() {
3977   auto inputMatrix = matrix().getType().cast<spirv::MatrixType>();
3978   auto resultMatrix = result().getType().cast<spirv::MatrixType>();
3979 
3980   // Verify that the input and output matrices have correct shapes.
3981   if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
3982     return emitError("input matrix rows count must be equal to "
3983                      "output matrix columns count");
3984 
3985   if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
3986     return emitError("input matrix columns count must be equal to "
3987                      "output matrix rows count");
3988 
3989   // Verify that the input and output matrices have the same component type
3990   if (inputMatrix.getElementType() != resultMatrix.getElementType())
3991     return emitError("input and output matrices must have the same "
3992                      "component type");
3993 
3994   return success();
3995 }
3996 
3997 //===----------------------------------------------------------------------===//
3998 // spv.MatrixTimesMatrix
3999 //===----------------------------------------------------------------------===//
4000 
verify()4001 LogicalResult spirv::MatrixTimesMatrixOp::verify() {
4002   auto leftMatrix = leftmatrix().getType().cast<spirv::MatrixType>();
4003   auto rightMatrix = rightmatrix().getType().cast<spirv::MatrixType>();
4004   auto resultMatrix = result().getType().cast<spirv::MatrixType>();
4005 
4006   // left matrix columns' count and right matrix rows' count must be equal
4007   if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
4008     return emitError("left matrix columns' count must be equal to "
4009                      "the right matrix rows' count");
4010 
4011   // right and result matrices columns' count must be the same
4012   if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
4013     return emitError(
4014         "right and result matrices must have equal columns' count");
4015 
4016   // right and result matrices component type must be the same
4017   if (rightMatrix.getElementType() != resultMatrix.getElementType())
4018     return emitError("right and result matrices' component type must"
4019                      " be the same");
4020 
4021   // left and result matrices component type must be the same
4022   if (leftMatrix.getElementType() != resultMatrix.getElementType())
4023     return emitError("left and result matrices' component type"
4024                      " must be the same");
4025 
4026   // left and result matrices rows count must be the same
4027   if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
4028     return emitError("left and result matrices must have equal rows' count");
4029 
4030   return success();
4031 }
4032 
4033 //===----------------------------------------------------------------------===//
4034 // spv.SpecConstantComposite
4035 //===----------------------------------------------------------------------===//
4036 
parse(OpAsmParser & parser,OperationState & state)4037 ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
4038                                                   OperationState &state) {
4039 
4040   StringAttr compositeName;
4041   if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
4042                              state.attributes))
4043     return failure();
4044 
4045   if (parser.parseLParen())
4046     return failure();
4047 
4048   SmallVector<Attribute, 4> constituents;
4049 
4050   do {
4051     // The name of the constituent attribute isn't important
4052     const char *attrName = "spec_const";
4053     FlatSymbolRefAttr specConstRef;
4054     NamedAttrList attrs;
4055 
4056     if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
4057       return failure();
4058 
4059     constituents.push_back(specConstRef);
4060   } while (!parser.parseOptionalComma());
4061 
4062   if (parser.parseRParen())
4063     return failure();
4064 
4065   state.addAttribute(kCompositeSpecConstituentsName,
4066                      parser.getBuilder().getArrayAttr(constituents));
4067 
4068   Type type;
4069   if (parser.parseColonType(type))
4070     return failure();
4071 
4072   state.addAttribute(kTypeAttrName, TypeAttr::get(type));
4073 
4074   return success();
4075 }
4076 
print(OpAsmPrinter & printer)4077 void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
4078   printer << " ";
4079   printer.printSymbolName(sym_name());
4080   printer << " (";
4081   auto constituents = this->constituents().getValue();
4082 
4083   if (!constituents.empty())
4084     llvm::interleaveComma(constituents, printer);
4085 
4086   printer << ") : " << type();
4087 }
4088 
verify()4089 LogicalResult spirv::SpecConstantCompositeOp::verify() {
4090   auto cType = type().dyn_cast<spirv::CompositeType>();
4091   auto constituents = this->constituents().getValue();
4092 
4093   if (!cType)
4094     return emitError("result type must be a composite type, but provided ")
4095            << type();
4096 
4097   if (cType.isa<spirv::CooperativeMatrixNVType>())
4098     return emitError("unsupported composite type  ") << cType;
4099   if (constituents.size() != cType.getNumElements())
4100     return emitError("has incorrect number of operands: expected ")
4101            << cType.getNumElements() << ", but provided "
4102            << constituents.size();
4103 
4104   for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
4105     auto constituent = constituents[index].cast<FlatSymbolRefAttr>();
4106 
4107     auto constituentSpecConstOp =
4108         dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
4109             (*this)->getParentOp(), constituent.getAttr()));
4110 
4111     if (constituentSpecConstOp.default_value().getType() !=
4112         cType.getElementType(index))
4113       return emitError("has incorrect types of operands: expected ")
4114              << cType.getElementType(index) << ", but provided "
4115              << constituentSpecConstOp.default_value().getType();
4116   }
4117 
4118   return success();
4119 }
4120 
4121 //===----------------------------------------------------------------------===//
4122 // spv.SpecConstantOperation
4123 //===----------------------------------------------------------------------===//
4124 
parse(OpAsmParser & parser,OperationState & state)4125 ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
4126                                                   OperationState &state) {
4127   Region *body = state.addRegion();
4128 
4129   if (parser.parseKeyword("wraps"))
4130     return failure();
4131 
4132   body->push_back(new Block);
4133   Block &block = body->back();
4134   Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
4135 
4136   if (!wrappedOp)
4137     return failure();
4138 
4139   OpBuilder builder(parser.getContext());
4140   builder.setInsertionPointToEnd(&block);
4141   builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
4142   state.location = wrappedOp->getLoc();
4143 
4144   state.addTypes(wrappedOp->getResult(0).getType());
4145 
4146   if (parser.parseOptionalAttrDict(state.attributes))
4147     return failure();
4148 
4149   return success();
4150 }
4151 
print(OpAsmPrinter & printer)4152 void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
4153   printer << " wraps ";
4154   printer.printGenericOp(&body().front().front());
4155 }
4156 
verifyRegions()4157 LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
4158   Block &block = getRegion().getBlocks().front();
4159 
4160   if (block.getOperations().size() != 2)
4161     return emitOpError("expected exactly 2 nested ops");
4162 
4163   Operation &enclosedOp = block.getOperations().front();
4164 
4165   if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
4166     return emitOpError("invalid enclosed op");
4167 
4168   for (auto operand : enclosedOp.getOperands())
4169     if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
4170              spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
4171       return emitOpError(
4172           "invalid operand, must be defined by a constant operation");
4173 
4174   return success();
4175 }
4176 
4177 //===----------------------------------------------------------------------===//
4178 // spv.GL.FrexpStruct
4179 //===----------------------------------------------------------------------===//
4180 
verify()4181 LogicalResult spirv::GLFrexpStructOp::verify() {
4182   spirv::StructType structTy = result().getType().dyn_cast<spirv::StructType>();
4183 
4184   if (structTy.getNumElements() != 2)
4185     return emitError("result type must be a struct type with two memebers");
4186 
4187   Type significandTy = structTy.getElementType(0);
4188   Type exponentTy = structTy.getElementType(1);
4189   VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
4190   IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();
4191 
4192   Type operandTy = operand().getType();
4193   VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
4194   FloatType operandFTy = operandTy.dyn_cast<FloatType>();
4195 
4196   if (significandTy != operandTy)
4197     return emitError("member zero of the resulting struct type must be the "
4198                      "same type as the operand");
4199 
4200   if (exponentVecTy) {
4201     IntegerType componentIntTy =
4202         exponentVecTy.getElementType().dyn_cast<IntegerType>();
4203     if (!componentIntTy || componentIntTy.getWidth() != 32)
4204       return emitError("member one of the resulting struct type must"
4205                        "be a scalar or vector of 32 bit integer type");
4206   } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
4207     return emitError("member one of the resulting struct type "
4208                      "must be a scalar or vector of 32 bit integer type");
4209   }
4210 
4211   // Check that the two member types have the same number of components
4212   if (operandVecTy && exponentVecTy &&
4213       (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
4214     return success();
4215 
4216   if (operandFTy && exponentIntTy)
4217     return success();
4218 
4219   return emitError("member one of the resulting struct type must have the same "
4220                    "number of components as the operand type");
4221 }
4222 
4223 //===----------------------------------------------------------------------===//
4224 // spv.GL.Ldexp
4225 //===----------------------------------------------------------------------===//
4226 
verify()4227 LogicalResult spirv::GLLdexpOp::verify() {
4228   Type significandType = x().getType();
4229   Type exponentType = exp().getType();
4230 
4231   if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
4232     return emitOpError("operands must both be scalars or vectors");
4233 
4234   auto getNumElements = [](Type type) -> unsigned {
4235     if (auto vectorType = type.dyn_cast<VectorType>())
4236       return vectorType.getNumElements();
4237     return 1;
4238   };
4239 
4240   if (getNumElements(significandType) != getNumElements(exponentType))
4241     return emitOpError("operands must have the same number of elements");
4242 
4243   return success();
4244 }
4245 
4246 //===----------------------------------------------------------------------===//
4247 // spv.ImageDrefGather
4248 //===----------------------------------------------------------------------===//
4249 
verify()4250 LogicalResult spirv::ImageDrefGatherOp::verify() {
4251   VectorType resultType = result().getType().cast<VectorType>();
4252   auto sampledImageType =
4253       sampledimage().getType().cast<spirv::SampledImageType>();
4254   auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();
4255 
4256   if (resultType.getNumElements() != 4)
4257     return emitOpError("result type must be a vector of four components");
4258 
4259   Type elementType = resultType.getElementType();
4260   Type sampledElementType = imageType.getElementType();
4261   if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType)
4262     return emitOpError(
4263         "the component type of result must be the same as sampled type of the "
4264         "underlying image type");
4265 
4266   spirv::Dim imageDim = imageType.getDim();
4267   spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
4268 
4269   if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
4270       imageDim != spirv::Dim::Rect)
4271     return emitOpError(
4272         "the Dim operand of the underlying image type must be 2D, Cube, or "
4273         "Rect");
4274 
4275   if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
4276     return emitOpError("the MS operand of the underlying image type must be 0");
4277 
4278   spirv::ImageOperandsAttr attr = imageoperandsAttr();
4279   auto operandArguments = operand_arguments();
4280 
4281   return verifyImageOperands(*this, attr, operandArguments);
4282 }
4283 
4284 //===----------------------------------------------------------------------===//
4285 // spv.ShiftLeftLogicalOp
4286 //===----------------------------------------------------------------------===//
4287 
verify()4288 LogicalResult spirv::ShiftLeftLogicalOp::verify() {
4289   return verifyShiftOp(*this);
4290 }
4291 
4292 //===----------------------------------------------------------------------===//
4293 // spv.ShiftRightArithmeticOp
4294 //===----------------------------------------------------------------------===//
4295 
verify()4296 LogicalResult spirv::ShiftRightArithmeticOp::verify() {
4297   return verifyShiftOp(*this);
4298 }
4299 
4300 //===----------------------------------------------------------------------===//
4301 // spv.ShiftRightLogicalOp
4302 //===----------------------------------------------------------------------===//
4303 
verify()4304 LogicalResult spirv::ShiftRightLogicalOp::verify() {
4305   return verifyShiftOp(*this);
4306 }
4307 
4308 //===----------------------------------------------------------------------===//
4309 // spv.ImageQuerySize
4310 //===----------------------------------------------------------------------===//
4311 
verify()4312 LogicalResult spirv::ImageQuerySizeOp::verify() {
4313   spirv::ImageType imageType = image().getType().cast<spirv::ImageType>();
4314   Type resultType = result().getType();
4315 
4316   spirv::Dim dim = imageType.getDim();
4317   spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
4318   spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
4319   switch (dim) {
4320   case spirv::Dim::Dim1D:
4321   case spirv::Dim::Dim2D:
4322   case spirv::Dim::Dim3D:
4323   case spirv::Dim::Cube:
4324     if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
4325         samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
4326         samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
4327       return emitError(
4328           "if Dim is 1D, 2D, 3D, or Cube, "
4329           "it must also have either an MS of 1 or a Sampled of 0 or 2");
4330     break;
4331   case spirv::Dim::Buffer:
4332   case spirv::Dim::Rect:
4333     break;
4334   default:
4335     return emitError("the Dim operand of the image type must "
4336                      "be 1D, 2D, 3D, Buffer, Cube, or Rect");
4337   }
4338 
4339   unsigned componentNumber = 0;
4340   switch (dim) {
4341   case spirv::Dim::Dim1D:
4342   case spirv::Dim::Buffer:
4343     componentNumber = 1;
4344     break;
4345   case spirv::Dim::Dim2D:
4346   case spirv::Dim::Cube:
4347   case spirv::Dim::Rect:
4348     componentNumber = 2;
4349     break;
4350   case spirv::Dim::Dim3D:
4351     componentNumber = 3;
4352     break;
4353   default:
4354     break;
4355   }
4356 
4357   if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
4358     componentNumber += 1;
4359 
4360   unsigned resultComponentNumber = 1;
4361   if (auto resultVectorType = resultType.dyn_cast<VectorType>())
4362     resultComponentNumber = resultVectorType.getNumElements();
4363 
4364   if (componentNumber != resultComponentNumber)
4365     return emitError("expected the result to have ")
4366            << componentNumber << " component(s), but found "
4367            << resultComponentNumber << " component(s)";
4368 
4369   return success();
4370 }
4371 
parsePtrAccessChainOpImpl(StringRef opName,OpAsmParser & parser,OperationState & state)4372 static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
4373                                              OpAsmParser &parser,
4374                                              OperationState &state) {
4375   OpAsmParser::UnresolvedOperand ptrInfo;
4376   SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
4377   Type type;
4378   auto loc = parser.getCurrentLocation();
4379   SmallVector<Type, 4> indicesTypes;
4380 
4381   if (parser.parseOperand(ptrInfo) ||
4382       parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
4383       parser.parseColonType(type) ||
4384       parser.resolveOperand(ptrInfo, type, state.operands))
4385     return failure();
4386 
4387   // Check that the provided indices list is not empty before parsing their
4388   // type list.
4389   if (indicesInfo.empty())
4390     return emitError(state.location) << opName << " expected element";
4391 
4392   if (parser.parseComma() || parser.parseTypeList(indicesTypes))
4393     return failure();
4394 
4395   // Check that the indices types list is not empty and that it has a one-to-one
4396   // mapping to the provided indices.
4397   if (indicesTypes.size() != indicesInfo.size())
4398     return emitError(state.location)
4399            << opName
4400            << " indices types' count must be equal to indices info count";
4401 
4402   if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
4403     return failure();
4404 
4405   auto resultType = getElementPtrType(
4406       type, llvm::makeArrayRef(state.operands).drop_front(2), state.location);
4407   if (!resultType)
4408     return failure();
4409 
4410   state.addTypes(resultType);
4411   return success();
4412 }
4413 
4414 template <typename Op>
concatElemAndIndices(Op op)4415 static auto concatElemAndIndices(Op op) {
4416   SmallVector<Value> ret(op.indices().size() + 1);
4417   ret[0] = op.element();
4418   llvm::copy(op.indices(), ret.begin() + 1);
4419   return ret;
4420 }
4421 
4422 //===----------------------------------------------------------------------===//
4423 // spv.InBoundsPtrAccessChainOp
4424 //===----------------------------------------------------------------------===//
4425 
build(OpBuilder & builder,OperationState & state,Value basePtr,Value element,ValueRange indices)4426 void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
4427                                             OperationState &state,
4428                                             Value basePtr, Value element,
4429                                             ValueRange indices) {
4430   auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4431   assert(type && "Unable to deduce return type based on basePtr and indices");
4432   build(builder, state, type, basePtr, element, indices);
4433 }
4434 
parse(OpAsmParser & parser,OperationState & state)4435 ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
4436                                                    OperationState &state) {
4437   return parsePtrAccessChainOpImpl(
4438       spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, state);
4439 }
4440 
print(OpAsmPrinter & printer)4441 void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
4442   printAccessChain(*this, concatElemAndIndices(*this), printer);
4443 }
4444 
verify()4445 LogicalResult spirv::InBoundsPtrAccessChainOp::verify() {
4446   return verifyAccessChain(*this, indices());
4447 }
4448 
4449 //===----------------------------------------------------------------------===//
4450 // spv.PtrAccessChainOp
4451 //===----------------------------------------------------------------------===//
4452 
build(OpBuilder & builder,OperationState & state,Value basePtr,Value element,ValueRange indices)4453 void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
4454                                     Value basePtr, Value element,
4455                                     ValueRange indices) {
4456   auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4457   assert(type && "Unable to deduce return type based on basePtr and indices");
4458   build(builder, state, type, basePtr, element, indices);
4459 }
4460 
parse(OpAsmParser & parser,OperationState & state)4461 ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser,
4462                                            OperationState &state) {
4463   return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
4464                                    parser, state);
4465 }
4466 
print(OpAsmPrinter & printer)4467 void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) {
4468   printAccessChain(*this, concatElemAndIndices(*this), printer);
4469 }
4470 
verify()4471 LogicalResult spirv::PtrAccessChainOp::verify() {
4472   return verifyAccessChain(*this, indices());
4473 }
4474 
4475 //===----------------------------------------------------------------------===//
4476 // spv.VectorTimesScalarOp
4477 //===----------------------------------------------------------------------===//
4478 
verify()4479 LogicalResult spirv::VectorTimesScalarOp::verify() {
4480   if (vector().getType() != getType())
4481     return emitOpError("vector operand and result type mismatch");
4482   auto scalarType = getType().cast<VectorType>().getElementType();
4483   if (scalar().getType() != scalarType)
4484     return emitOpError("scalar operand and result element type match");
4485   return success();
4486 }
4487 
4488 // TableGen'erated operation interfaces for querying versions, extensions, and
4489 // capabilities.
4490 #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
4491 
// TableGen'erated operation definitions.
4493 #define GET_OP_CLASSES
4494 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
4495 
4496 namespace mlir {
4497 namespace spirv {
4498 // TableGen'erated operation availability interface implementations.
4499 #include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
4500 } // namespace spirv
4501 } // namespace mlir
4502