//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
17 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/Parser.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringMap.h"
29 #include "llvm/ADT/StringSwitch.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 using namespace mlir;
34 using namespace mlir::spirv;
35 
36 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
37 
38 //===----------------------------------------------------------------------===//
39 // InlinerInterface
40 //===----------------------------------------------------------------------===//
41 
42 /// Returns true if the given region contains spv.Return or spv.ReturnValue ops.
43 static inline bool containsReturn(Region &region) {
44   return llvm::any_of(region, [](Block &block) {
45     Operation *terminator = block.getTerminator();
46     return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
47   });
48 }
49 
50 namespace {
51 /// This class defines the interface for inlining within the SPIR-V dialect.
52 struct SPIRVInlinerInterface : public DialectInlinerInterface {
53   using DialectInlinerInterface::DialectInlinerInterface;
54 
55   /// All call operations within SPIRV can be inlined.
56   bool isLegalToInline(Operation *call, Operation *callable,
57                        bool wouldBeCloned) const final {
58     return true;
59   }
60 
61   /// Returns true if the given region 'src' can be inlined into the region
62   /// 'dest' that is attached to an operation registered to the current dialect.
63   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
64                        BlockAndValueMapping &) const final {
65     // Return true here when inlining into spv.func, spv.mlir.selection, and
66     // spv.mlir.loop operations.
67     auto *op = dest->getParentOp();
68     return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
69   }
70 
71   /// Returns true if the given operation 'op', that is registered to this
72   /// dialect, can be inlined into the region 'dest' that is attached to an
73   /// operation registered to the current dialect.
74   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
75                        BlockAndValueMapping &) const final {
76     // TODO: Enable inlining structured control flows with return.
77     if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
78         containsReturn(op->getRegion(0)))
79       return false;
80     // TODO: we need to filter OpKill here to avoid inlining it to
81     // a loop continue construct:
82     // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
83     // However OpKill is fragment shader specific and we don't support it yet.
84     return true;
85   }
86 
87   /// Handle the given inlined terminator by replacing it with a new operation
88   /// as necessary.
89   void handleTerminator(Operation *op, Block *newDest) const final {
90     if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
91       OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
92       op->erase();
93     } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
94       llvm_unreachable("unimplemented spv.ReturnValue in inliner");
95     }
96   }
97 
98   /// Handle the given inlined terminator by replacing it with a new operation
99   /// as necessary.
100   void handleTerminator(Operation *op,
101                         ArrayRef<Value> valuesToRepl) const final {
102     // Only spv.ReturnValue needs to be handled here.
103     auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
104     if (!retValOp)
105       return;
106 
107     // Replace the values directly with the return operands.
108     assert(valuesToRepl.size() == 1 &&
109            "spv.ReturnValue expected to only handle one result");
110     valuesToRepl.front().replaceAllUsesWith(retValOp.value());
111   }
112 };
113 } // namespace
114 
115 //===----------------------------------------------------------------------===//
116 // SPIR-V Dialect
117 //===----------------------------------------------------------------------===//
118 
/// Sets up the dialect. It calls the attribute and type registration hooks,
/// adds every op listed in SPIRVOps.cpp.inc, attaches the inliner interface
/// defined above, and allows unregistered operations.
void SPIRVDialect::initialize() {
  registerAttributes();
  registerTypes();

  // Add SPIR-V ops. The op class list is expanded from SPIRVOps.cpp.inc when
  // GET_OP_LIST is defined (presumably TableGen-generated).
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
      >();

  // Make SPIR-V ops inlinable by the generic MLIR inliner.
  addInterfaces<SPIRVInlinerInterface>();

  // Allow unknown operations because SPIR-V is extensible.
  allowUnknownOperations();
}
134 
135 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
136   return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // Type Parsing
141 //===----------------------------------------------------------------------===//
142 
// Forward declarations.
// `parseAndVerify<ValTy>` parses one parameter of a parameterized SPIR-V type
// and returns llvm::None on failure. The primary template (defined further
// below) parses an enum keyword. The `Type` and `unsigned` specializations
// are also defined below. They are declared here so that earlier helpers
// (e.g. parseOptionalArrayStride) can use them.
template <typename ValTy>
static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
                                      DialectAsmParser &parser);
template <>
Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
                                    DialectAsmParser &parser);

template <>
Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
                                            DialectAsmParser &parser);
154 
155 static Type parseAndVerifyType(SPIRVDialect const &dialect,
156                                DialectAsmParser &parser) {
157   Type type;
158   llvm::SMLoc typeLoc = parser.getCurrentLocation();
159   if (parser.parseType(type))
160     return Type();
161 
162   // Allow SPIR-V dialect types
163   if (&type.getDialect() == &dialect)
164     return type;
165 
166   // Check other allowed types
167   if (auto t = type.dyn_cast<FloatType>()) {
168     if (type.isBF16()) {
169       parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
170       return Type();
171     }
172   } else if (auto t = type.dyn_cast<IntegerType>()) {
173     if (!ScalarType::isValid(t)) {
174       parser.emitError(typeLoc,
175                        "only 1/8/16/32/64-bit integer type allowed but found ")
176           << type;
177       return Type();
178     }
179   } else if (auto t = type.dyn_cast<VectorType>()) {
180     if (t.getRank() != 1) {
181       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
182       return Type();
183     }
184     if (t.getNumElements() > 4) {
185       parser.emitError(
186           typeLoc, "vector length has to be less than or equal to 4 but found ")
187           << t.getNumElements();
188       return Type();
189     }
190   } else {
191     parser.emitError(typeLoc, "cannot use ")
192         << type << " to compose SPIR-V types";
193     return Type();
194   }
195 
196   return type;
197 }
198 
199 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
200                                      DialectAsmParser &parser) {
201   Type type;
202   llvm::SMLoc typeLoc = parser.getCurrentLocation();
203   if (parser.parseType(type))
204     return Type();
205 
206   if (auto t = type.dyn_cast<VectorType>()) {
207     if (t.getRank() != 1) {
208       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
209       return Type();
210     }
211     if (t.getNumElements() > 4 || t.getNumElements() < 2) {
212       parser.emitError(typeLoc,
213                        "matrix columns size has to be less than or equal "
214                        "to 4 and greater than or equal 2, but found ")
215           << t.getNumElements();
216       return Type();
217     }
218 
219     if (!t.getElementType().isa<FloatType>()) {
220       parser.emitError(typeLoc, "matrix columns' elements must be of "
221                                 "Float type, got ")
222           << t.getElementType();
223       return Type();
224     }
225   } else {
226     parser.emitError(typeLoc, "matrix must be composed using vector "
227                               "type, got ")
228         << type;
229     return Type();
230   }
231 
232   return type;
233 }
234 
235 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
236                                            DialectAsmParser &parser) {
237   Type type;
238   llvm::SMLoc typeLoc = parser.getCurrentLocation();
239   if (parser.parseType(type))
240     return Type();
241 
242   if (!type.isa<ImageType>()) {
243     parser.emitError(typeLoc,
244                      "sampled image must be composed using image type, got ")
245         << type;
246     return Type();
247   }
248 
249   return type;
250 }
251 
252 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
253 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
254 /// missing.
255 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
256                                               DialectAsmParser &parser,
257                                               unsigned &stride) {
258   if (failed(parser.parseOptionalComma())) {
259     stride = 0;
260     return success();
261   }
262 
263   if (parser.parseKeyword("stride") || parser.parseEqual())
264     return failure();
265 
266   llvm::SMLoc strideLoc = parser.getCurrentLocation();
267   Optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
268   if (!optStride)
269     return failure();
270 
271   if (!(stride = optStride.getValue())) {
272     parser.emitError(strideLoc, "ArrayStride must be greater than zero");
273     return failure();
274   }
275   return success();
276 }
277 
278 // element-type ::= integer-type
279 //                | floating-point-type
280 //                | vector-type
281 //                | spirv-type
282 //
283 // array-type ::= `!spv.array` `<` integer-literal `x` element-type
284 //                (`,` `stride` `=` integer-literal)? `>`
285 static Type parseArrayType(SPIRVDialect const &dialect,
286                            DialectAsmParser &parser) {
287   if (parser.parseLess())
288     return Type();
289 
290   SmallVector<int64_t, 1> countDims;
291   llvm::SMLoc countLoc = parser.getCurrentLocation();
292   if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
293     return Type();
294   if (countDims.size() != 1) {
295     parser.emitError(countLoc,
296                      "expected single integer for array element count");
297     return Type();
298   }
299 
300   // According to the SPIR-V spec:
301   // "Length is the number of elements in the array. It must be at least 1."
302   int64_t count = countDims[0];
303   if (count == 0) {
304     parser.emitError(countLoc, "expected array length greater than 0");
305     return Type();
306   }
307 
308   Type elementType = parseAndVerifyType(dialect, parser);
309   if (!elementType)
310     return Type();
311 
312   unsigned stride = 0;
313   if (failed(parseOptionalArrayStride(dialect, parser, stride)))
314     return Type();
315 
316   if (parser.parseGreater())
317     return Type();
318   return ArrayType::get(elementType, count, stride);
319 }
320 
321 // cooperative-matrix-type ::= `!spv.coopmatrix` `<` element-type ',' scope ','
322 //                                                   rows ',' columns>`
323 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
324                                        DialectAsmParser &parser) {
325   if (parser.parseLess())
326     return Type();
327 
328   SmallVector<int64_t, 2> dims;
329   llvm::SMLoc countLoc = parser.getCurrentLocation();
330   if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
331     return Type();
332 
333   if (dims.size() != 2) {
334     parser.emitError(countLoc, "expected rows and columns size");
335     return Type();
336   }
337 
338   auto elementTy = parseAndVerifyType(dialect, parser);
339   if (!elementTy)
340     return Type();
341 
342   Scope scope;
343   if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
344     return Type();
345 
346   if (parser.parseGreater())
347     return Type();
348   return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
349 }
350 
351 // TODO: Reorder methods to be utilities first and parse*Type
352 // methods in alphabetical order
353 //
354 // storage-class ::= `UniformConstant`
355 //                 | `Uniform`
356 //                 | `Workgroup`
357 //                 | <and other storage classes...>
358 //
359 // pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>`
360 static Type parsePointerType(SPIRVDialect const &dialect,
361                              DialectAsmParser &parser) {
362   if (parser.parseLess())
363     return Type();
364 
365   auto pointeeType = parseAndVerifyType(dialect, parser);
366   if (!pointeeType)
367     return Type();
368 
369   StringRef storageClassSpec;
370   llvm::SMLoc storageClassLoc = parser.getCurrentLocation();
371   if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
372     return Type();
373 
374   auto storageClass = symbolizeStorageClass(storageClassSpec);
375   if (!storageClass) {
376     parser.emitError(storageClassLoc, "unknown storage class: ")
377         << storageClassSpec;
378     return Type();
379   }
380   if (parser.parseGreater())
381     return Type();
382   return PointerType::get(pointeeType, *storageClass);
383 }
384 
385 // runtime-array-type ::= `!spv.rtarray` `<` element-type
386 //                        (`,` `stride` `=` integer-literal)? `>`
387 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
388                                   DialectAsmParser &parser) {
389   if (parser.parseLess())
390     return Type();
391 
392   Type elementType = parseAndVerifyType(dialect, parser);
393   if (!elementType)
394     return Type();
395 
396   unsigned stride = 0;
397   if (failed(parseOptionalArrayStride(dialect, parser, stride)))
398     return Type();
399 
400   if (parser.parseGreater())
401     return Type();
402   return RuntimeArrayType::get(elementType, stride);
403 }
404 
405 // matrix-type ::= `!spv.matrix` `<` integer-literal `x` element-type `>`
406 static Type parseMatrixType(SPIRVDialect const &dialect,
407                             DialectAsmParser &parser) {
408   if (parser.parseLess())
409     return Type();
410 
411   SmallVector<int64_t, 1> countDims;
412   llvm::SMLoc countLoc = parser.getCurrentLocation();
413   if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
414     return Type();
415   if (countDims.size() != 1) {
416     parser.emitError(countLoc, "expected single unsigned "
417                                "integer for number of columns");
418     return Type();
419   }
420 
421   int64_t columnCount = countDims[0];
422   // According to the specification, Matrices can have 2, 3, or 4 columns
423   if (columnCount < 2 || columnCount > 4) {
424     parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
425                                "columns");
426     return Type();
427   }
428 
429   Type columnType = parseAndVerifyMatrixType(dialect, parser);
430   if (!columnType)
431     return Type();
432 
433   if (parser.parseGreater())
434     return Type();
435 
436   return MatrixType::get(columnType, columnCount);
437 }
438 
439 // Specialize this function to parse each of the parameters that define an
440 // ImageType. By default it assumes this is an enum type.
441 template <typename ValTy>
442 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
443                                       DialectAsmParser &parser) {
444   StringRef enumSpec;
445   llvm::SMLoc enumLoc = parser.getCurrentLocation();
446   if (parser.parseKeyword(&enumSpec)) {
447     return llvm::None;
448   }
449 
450   auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
451   if (!val)
452     parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
453   return val;
454 }
455 
456 template <>
457 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
458                                     DialectAsmParser &parser) {
459   // TODO: Further verify that the element type can be sampled
460   auto ty = parseAndVerifyType(dialect, parser);
461   if (!ty)
462     return llvm::None;
463   return ty;
464 }
465 
466 template <typename IntTy>
467 static Optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
468                                              DialectAsmParser &parser) {
469   IntTy offsetVal = std::numeric_limits<IntTy>::max();
470   if (parser.parseInteger(offsetVal))
471     return llvm::None;
472   return offsetVal;
473 }
474 
475 template <>
476 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
477                                             DialectAsmParser &parser) {
478   return parseAndVerifyInteger<unsigned>(dialect, parser);
479 }
480 
481 namespace {
482 // Functor object to parse a comma separated list of specs. The function
483 // parseAndVerify does the actual parsing and verification of individual
484 // elements. This is a functor since parsing the last element of the list
485 // (termination condition) needs partial specialization.
486 template <typename ParseType, typename... Args> struct ParseCommaSeparatedList {
487   Optional<std::tuple<ParseType, Args...>>
488   operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
489     auto parseVal = parseAndVerify<ParseType>(dialect, parser);
490     if (!parseVal)
491       return llvm::None;
492 
493     auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
494     if (numArgs != 0 && failed(parser.parseComma()))
495       return llvm::None;
496     auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
497     if (!remainingValues)
498       return llvm::None;
499     return std::tuple_cat(std::tuple<ParseType>(parseVal.getValue()),
500                           remainingValues.getValue());
501   }
502 };
503 
504 // Partial specialization of the function to parse a comma separated list of
505 // specs to parse the last element of the list.
506 template <typename ParseType> struct ParseCommaSeparatedList<ParseType> {
507   Optional<std::tuple<ParseType>> operator()(SPIRVDialect const &dialect,
508                                              DialectAsmParser &parser) const {
509     if (auto value = parseAndVerify<ParseType>(dialect, parser))
510       return std::tuple<ParseType>(value.getValue());
511     return llvm::None;
512   }
513 };
514 } // namespace
515 
516 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
517 //
518 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
519 //
520 // arrayed-info ::= `NonArrayed` | `Arrayed`
521 //
522 // sampling-info ::= `SingleSampled` | `MultiSampled`
523 //
524 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` |  `NoSampler`
525 //
526 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
527 //
528 // image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,`
529 //                              arrayed-info `,` sampling-info `,`
530 //                              sampler-use-info `,` format `>`
531 static Type parseImageType(SPIRVDialect const &dialect,
532                            DialectAsmParser &parser) {
533   if (parser.parseLess())
534     return Type();
535 
536   auto value =
537       ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
538                               ImageSamplingInfo, ImageSamplerUseInfo,
539                               ImageFormat>{}(dialect, parser);
540   if (!value)
541     return Type();
542 
543   if (parser.parseGreater())
544     return Type();
545   return ImageType::get(value.getValue());
546 }
547 
548 // sampledImage-type :: = `!spv.sampledImage<` image-type `>`
549 static Type parseSampledImageType(SPIRVDialect const &dialect,
550                                   DialectAsmParser &parser) {
551   if (parser.parseLess())
552     return Type();
553 
554   Type parsedType = parseAndVerifySampledImageType(dialect, parser);
555   if (!parsedType)
556     return Type();
557 
558   if (parser.parseGreater())
559     return Type();
560   return SampledImageType::get(parsedType);
561 }
562 
563 // Parse decorations associated with a member.
564 static ParseResult parseStructMemberDecorations(
565     SPIRVDialect const &dialect, DialectAsmParser &parser,
566     ArrayRef<Type> memberTypes,
567     SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
568     SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
569 
570   // Check if the first element is offset.
571   llvm::SMLoc offsetLoc = parser.getCurrentLocation();
572   StructType::OffsetInfo offset = 0;
573   OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
574   if (offsetParseResult.hasValue()) {
575     if (failed(*offsetParseResult))
576       return failure();
577 
578     if (offsetInfo.size() != memberTypes.size() - 1) {
579       return parser.emitError(offsetLoc,
580                               "offset specification must be given for "
581                               "all members");
582     }
583     offsetInfo.push_back(offset);
584   }
585 
586   // Check for no spirv::Decorations.
587   if (succeeded(parser.parseOptionalRSquare()))
588     return success();
589 
590   // If there was an offset, make sure to parse the comma.
591   if (offsetParseResult.hasValue() && parser.parseComma())
592     return failure();
593 
594   // Check for spirv::Decorations.
595   do {
596     auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
597     if (!memberDecoration)
598       return failure();
599 
600     // Parse member decoration value if it exists.
601     if (succeeded(parser.parseOptionalEqual())) {
602       auto memberDecorationValue =
603           parseAndVerifyInteger<uint32_t>(dialect, parser);
604 
605       if (!memberDecorationValue)
606         return failure();
607 
608       memberDecorationInfo.emplace_back(
609           static_cast<uint32_t>(memberTypes.size() - 1), 1,
610           memberDecoration.getValue(), memberDecorationValue.getValue());
611     } else {
612       memberDecorationInfo.emplace_back(
613           static_cast<uint32_t>(memberTypes.size() - 1), 0,
614           memberDecoration.getValue(), 0);
615     }
616 
617   } while (succeeded(parser.parseOptionalComma()));
618 
619   return parser.parseRSquare();
620 }
621 
622 // struct-member-decoration ::= integer-literal? spirv-decoration*
623 // struct-type ::=
624 //             `!spv.struct<` (id `,`)?
625 //                          `(`
626 //                            (spirv-type (`[` struct-member-decoration `]`)?)*
627 //                          `)>`
628 static Type parseStructType(SPIRVDialect const &dialect,
629                             DialectAsmParser &parser) {
630   // TODO: This function is quite lengthy. Break it down into smaller chunks.
631 
632   // To properly resolve recursive references while parsing recursive struct
633   // types, we need to maintain a list of enclosing struct type names. This set
634   // maintains the names of struct types in which the type we are about to parse
635   // is nested.
636   //
637   // Note: This has to be thread_local to enable multiple threads to safely
638   // parse concurrently.
639   thread_local SetVector<StringRef> structContext;
640 
641   static auto removeIdentifierAndFail = [](SetVector<StringRef> &structContext,
642                                            StringRef identifier) {
643     if (!identifier.empty())
644       structContext.remove(identifier);
645 
646     return Type();
647   };
648 
649   if (parser.parseLess())
650     return Type();
651 
652   StringRef identifier;
653 
654   // Check if this is an identified struct type.
655   if (succeeded(parser.parseOptionalKeyword(&identifier))) {
656     // Check if this is a possible recursive reference.
657     if (succeeded(parser.parseOptionalGreater())) {
658       if (structContext.count(identifier) == 0) {
659         parser.emitError(
660             parser.getNameLoc(),
661             "recursive struct reference not nested in struct definition");
662 
663         return Type();
664       }
665 
666       return StructType::getIdentified(dialect.getContext(), identifier);
667     }
668 
669     if (failed(parser.parseComma()))
670       return Type();
671 
672     if (structContext.count(identifier) != 0) {
673       parser.emitError(parser.getNameLoc(),
674                        "identifier already used for an enclosing struct");
675 
676       return removeIdentifierAndFail(structContext, identifier);
677     }
678 
679     structContext.insert(identifier);
680   }
681 
682   if (failed(parser.parseLParen()))
683     return removeIdentifierAndFail(structContext, identifier);
684 
685   if (succeeded(parser.parseOptionalRParen()) &&
686       succeeded(parser.parseOptionalGreater())) {
687     if (!identifier.empty())
688       structContext.remove(identifier);
689 
690     return StructType::getEmpty(dialect.getContext(), identifier);
691   }
692 
693   StructType idStructTy;
694 
695   if (!identifier.empty())
696     idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
697 
698   SmallVector<Type, 4> memberTypes;
699   SmallVector<StructType::OffsetInfo, 4> offsetInfo;
700   SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
701 
702   do {
703     Type memberType;
704     if (parser.parseType(memberType))
705       return removeIdentifierAndFail(structContext, identifier);
706     memberTypes.push_back(memberType);
707 
708     if (succeeded(parser.parseOptionalLSquare()))
709       if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
710                                        memberDecorationInfo))
711         return removeIdentifierAndFail(structContext, identifier);
712   } while (succeeded(parser.parseOptionalComma()));
713 
714   if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
715     parser.emitError(parser.getNameLoc(),
716                      "offset specification must be given for all members");
717     return removeIdentifierAndFail(structContext, identifier);
718   }
719 
720   if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
721     return removeIdentifierAndFail(structContext, identifier);
722 
723   if (!identifier.empty()) {
724     if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
725                                      memberDecorationInfo)))
726       return Type();
727 
728     structContext.remove(identifier);
729     return idStructTy;
730   }
731 
732   return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
733 }
734 
735 // spirv-type ::= array-type
736 //              | element-type
737 //              | image-type
738 //              | pointer-type
739 //              | runtime-array-type
740 //              | sampled-image-type
741 //              | struct-type
742 Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
743   StringRef keyword;
744   if (parser.parseKeyword(&keyword))
745     return Type();
746 
747   if (keyword == "array")
748     return parseArrayType(*this, parser);
749   if (keyword == "coopmatrix")
750     return parseCooperativeMatrixType(*this, parser);
751   if (keyword == "image")
752     return parseImageType(*this, parser);
753   if (keyword == "ptr")
754     return parsePointerType(*this, parser);
755   if (keyword == "rtarray")
756     return parseRuntimeArrayType(*this, parser);
757   if (keyword == "sampled_image")
758     return parseSampledImageType(*this, parser);
759   if (keyword == "struct")
760     return parseStructType(*this, parser);
761   if (keyword == "matrix")
762     return parseMatrixType(*this, parser);
763   parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
764   return Type();
765 }
766 
767 //===----------------------------------------------------------------------===//
768 // Type Printing
769 //===----------------------------------------------------------------------===//
770 
771 static void print(ArrayType type, DialectAsmPrinter &os) {
772   os << "array<" << type.getNumElements() << " x " << type.getElementType();
773   if (unsigned stride = type.getArrayStride())
774     os << ", stride=" << stride;
775   os << ">";
776 }
777 
778 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
779   os << "rtarray<" << type.getElementType();
780   if (unsigned stride = type.getArrayStride())
781     os << ", stride=" << stride;
782   os << ">";
783 }
784 
785 static void print(PointerType type, DialectAsmPrinter &os) {
786   os << "ptr<" << type.getPointeeType() << ", "
787      << stringifyStorageClass(type.getStorageClass()) << ">";
788 }
789 
790 static void print(ImageType type, DialectAsmPrinter &os) {
791   os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
792      << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
793      << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
794      << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
795      << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
796      << stringifyImageFormat(type.getImageFormat()) << ">";
797 }
798 
799 static void print(SampledImageType type, DialectAsmPrinter &os) {
800   os << "sampled_image<" << type.getImageType() << ">";
801 }
802 
803 static void print(StructType type, DialectAsmPrinter &os) {
804   thread_local SetVector<StringRef> structContext;
805 
806   os << "struct<";
807 
808   if (type.isIdentified()) {
809     os << type.getIdentifier();
810 
811     if (structContext.count(type.getIdentifier())) {
812       os << ">";
813       return;
814     }
815 
816     os << ", ";
817     structContext.insert(type.getIdentifier());
818   }
819 
820   os << "(";
821 
822   auto printMember = [&](unsigned i) {
823     os << type.getElementType(i);
824     SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
825     type.getMemberDecorations(i, decorations);
826     if (type.hasOffset() || !decorations.empty()) {
827       os << " [";
828       if (type.hasOffset()) {
829         os << type.getMemberOffset(i);
830         if (!decorations.empty())
831           os << ", ";
832       }
833       auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
834         os << stringifyDecoration(decoration.decoration);
835         if (decoration.hasValue) {
836           os << "=" << decoration.decorationValue;
837         }
838       };
839       llvm::interleaveComma(decorations, os, eachFn);
840       os << "]";
841     }
842   };
843   llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
844                         printMember);
845   os << ")>";
846 
847   if (type.isIdentified())
848     structContext.remove(type.getIdentifier());
849 }
850 
851 static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
852   os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
853   os << type.getElementType() << ", " << stringifyScope(type.getScope());
854   os << ">";
855 }
856 
857 static void print(MatrixType type, DialectAsmPrinter &os) {
858   os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
859   os << ">";
860 }
861 
862 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
863   TypeSwitch<Type>(type)
864       .Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
865             ImageType, SampledImageType, StructType, MatrixType>(
866           [&](auto type) { print(type, os); })
867       .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
868 }
869 
870 //===----------------------------------------------------------------------===//
871 // Attribute Parsing
872 //===----------------------------------------------------------------------===//
873 
874 /// Parses a comma-separated list of keywords, invokes `processKeyword` on each
875 /// of the parsed keyword, and returns failure if any error occurs.
876 static ParseResult parseKeywordList(
877     DialectAsmParser &parser,
878     function_ref<LogicalResult(llvm::SMLoc, StringRef)> processKeyword) {
879   if (parser.parseLSquare())
880     return failure();
881 
882   // Special case for empty list.
883   if (succeeded(parser.parseOptionalRSquare()))
884     return success();
885 
886   // Keep parsing the keyword and an optional comma following it. If the comma
887   // is successfully parsed, then we have more keywords to parse.
888   do {
889     auto loc = parser.getCurrentLocation();
890     StringRef keyword;
891     if (parser.parseKeyword(&keyword) || failed(processKeyword(loc, keyword)))
892       return failure();
893   } while (succeeded(parser.parseOptionalComma()));
894 
895   if (parser.parseRSquare())
896     return failure();
897 
898   return success();
899 }
900 
901 /// Parses a spirv::InterfaceVarABIAttr.
902 static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser) {
903   if (parser.parseLess())
904     return {};
905 
906   Builder &builder = parser.getBuilder();
907 
908   if (parser.parseLParen())
909     return {};
910 
911   IntegerAttr descriptorSetAttr;
912   {
913     auto loc = parser.getCurrentLocation();
914     uint32_t descriptorSet = 0;
915     auto descriptorSetParseResult = parser.parseOptionalInteger(descriptorSet);
916 
917     if (!descriptorSetParseResult.hasValue() ||
918         failed(*descriptorSetParseResult)) {
919       parser.emitError(loc, "missing descriptor set");
920       return {};
921     }
922     descriptorSetAttr = builder.getI32IntegerAttr(descriptorSet);
923   }
924 
925   if (parser.parseComma())
926     return {};
927 
928   IntegerAttr bindingAttr;
929   {
930     auto loc = parser.getCurrentLocation();
931     uint32_t binding = 0;
932     auto bindingParseResult = parser.parseOptionalInteger(binding);
933 
934     if (!bindingParseResult.hasValue() || failed(*bindingParseResult)) {
935       parser.emitError(loc, "missing binding");
936       return {};
937     }
938     bindingAttr = builder.getI32IntegerAttr(binding);
939   }
940 
941   if (parser.parseRParen())
942     return {};
943 
944   IntegerAttr storageClassAttr;
945   {
946     if (succeeded(parser.parseOptionalComma())) {
947       auto loc = parser.getCurrentLocation();
948       StringRef storageClass;
949       if (parser.parseKeyword(&storageClass))
950         return {};
951 
952       if (auto storageClassSymbol =
953               spirv::symbolizeStorageClass(storageClass)) {
954         storageClassAttr = builder.getI32IntegerAttr(
955             static_cast<uint32_t>(*storageClassSymbol));
956       } else {
957         parser.emitError(loc, "unknown storage class: ") << storageClass;
958         return {};
959       }
960     }
961   }
962 
963   if (parser.parseGreater())
964     return {};
965 
966   return spirv::InterfaceVarABIAttr::get(descriptorSetAttr, bindingAttr,
967                                          storageClassAttr);
968 }
969 
970 static Attribute parseVerCapExtAttr(DialectAsmParser &parser) {
971   if (parser.parseLess())
972     return {};
973 
974   Builder &builder = parser.getBuilder();
975 
976   IntegerAttr versionAttr;
977   {
978     auto loc = parser.getCurrentLocation();
979     StringRef version;
980     if (parser.parseKeyword(&version) || parser.parseComma())
981       return {};
982 
983     if (auto versionSymbol = spirv::symbolizeVersion(version)) {
984       versionAttr =
985           builder.getI32IntegerAttr(static_cast<uint32_t>(*versionSymbol));
986     } else {
987       parser.emitError(loc, "unknown version: ") << version;
988       return {};
989     }
990   }
991 
992   ArrayAttr capabilitiesAttr;
993   {
994     SmallVector<Attribute, 4> capabilities;
995     llvm::SMLoc errorloc;
996     StringRef errorKeyword;
997 
998     auto processCapability = [&](llvm::SMLoc loc, StringRef capability) {
999       if (auto capSymbol = spirv::symbolizeCapability(capability)) {
1000         capabilities.push_back(
1001             builder.getI32IntegerAttr(static_cast<uint32_t>(*capSymbol)));
1002         return success();
1003       }
1004       return errorloc = loc, errorKeyword = capability, failure();
1005     };
1006     if (parseKeywordList(parser, processCapability) || parser.parseComma()) {
1007       if (!errorKeyword.empty())
1008         parser.emitError(errorloc, "unknown capability: ") << errorKeyword;
1009       return {};
1010     }
1011 
1012     capabilitiesAttr = builder.getArrayAttr(capabilities);
1013   }
1014 
1015   ArrayAttr extensionsAttr;
1016   {
1017     SmallVector<Attribute, 1> extensions;
1018     llvm::SMLoc errorloc;
1019     StringRef errorKeyword;
1020 
1021     auto processExtension = [&](llvm::SMLoc loc, StringRef extension) {
1022       if (spirv::symbolizeExtension(extension)) {
1023         extensions.push_back(builder.getStringAttr(extension));
1024         return success();
1025       }
1026       return errorloc = loc, errorKeyword = extension, failure();
1027     };
1028     if (parseKeywordList(parser, processExtension)) {
1029       if (!errorKeyword.empty())
1030         parser.emitError(errorloc, "unknown extension: ") << errorKeyword;
1031       return {};
1032     }
1033 
1034     extensionsAttr = builder.getArrayAttr(extensions);
1035   }
1036 
1037   if (parser.parseGreater())
1038     return {};
1039 
1040   return spirv::VerCapExtAttr::get(versionAttr, capabilitiesAttr,
1041                                    extensionsAttr);
1042 }
1043 
1044 /// Parses a spirv::TargetEnvAttr.
1045 static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
1046   if (parser.parseLess())
1047     return {};
1048 
1049   spirv::VerCapExtAttr tripleAttr;
1050   if (parser.parseAttribute(tripleAttr) || parser.parseComma())
1051     return {};
1052 
1053   // Parse [vendor[:device-type[:device-id]]]
1054   Vendor vendorID = Vendor::Unknown;
1055   DeviceType deviceType = DeviceType::Unknown;
1056   uint32_t deviceID = spirv::TargetEnvAttr::kUnknownDeviceID;
1057   {
1058     auto loc = parser.getCurrentLocation();
1059     StringRef vendorStr;
1060     if (succeeded(parser.parseOptionalKeyword(&vendorStr))) {
1061       if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) {
1062         vendorID = *vendorSymbol;
1063       } else {
1064         parser.emitError(loc, "unknown vendor: ") << vendorStr;
1065       }
1066 
1067       if (succeeded(parser.parseOptionalColon())) {
1068         loc = parser.getCurrentLocation();
1069         StringRef deviceTypeStr;
1070         if (parser.parseKeyword(&deviceTypeStr))
1071           return {};
1072         if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) {
1073           deviceType = *deviceTypeSymbol;
1074         } else {
1075           parser.emitError(loc, "unknown device type: ") << deviceTypeStr;
1076         }
1077 
1078         if (succeeded(parser.parseOptionalColon())) {
1079           loc = parser.getCurrentLocation();
1080           if (parser.parseInteger(deviceID))
1081             return {};
1082         }
1083       }
1084       if (parser.parseComma())
1085         return {};
1086     }
1087   }
1088 
1089   DictionaryAttr limitsAttr;
1090   {
1091     auto loc = parser.getCurrentLocation();
1092     if (parser.parseAttribute(limitsAttr))
1093       return {};
1094 
1095     if (!limitsAttr.isa<spirv::ResourceLimitsAttr>()) {
1096       parser.emitError(
1097           loc,
1098           "limits must be a dictionary attribute containing two 32-bit integer "
1099           "attributes 'max_compute_workgroup_invocations' and "
1100           "'max_compute_workgroup_size'");
1101       return {};
1102     }
1103   }
1104 
1105   if (parser.parseGreater())
1106     return {};
1107 
1108   return spirv::TargetEnvAttr::get(tripleAttr, vendorID, deviceType, deviceID,
1109                                    limitsAttr);
1110 }
1111 
1112 Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
1113                                        Type type) const {
1114   // SPIR-V attributes are dictionaries so they do not have type.
1115   if (type) {
1116     parser.emitError(parser.getNameLoc(), "unexpected type");
1117     return {};
1118   }
1119 
1120   // Parse the kind keyword first.
1121   StringRef attrKind;
1122   if (parser.parseKeyword(&attrKind))
1123     return {};
1124 
1125   if (attrKind == spirv::TargetEnvAttr::getKindName())
1126     return parseTargetEnvAttr(parser);
1127   if (attrKind == spirv::VerCapExtAttr::getKindName())
1128     return parseVerCapExtAttr(parser);
1129   if (attrKind == spirv::InterfaceVarABIAttr::getKindName())
1130     return parseInterfaceVarABIAttr(parser);
1131 
1132   parser.emitError(parser.getNameLoc(), "unknown SPIR-V attribute kind: ")
1133       << attrKind;
1134   return {};
1135 }
1136 
1137 //===----------------------------------------------------------------------===//
1138 // Attribute Printing
1139 //===----------------------------------------------------------------------===//
1140 
1141 static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) {
1142   auto &os = printer.getStream();
1143   printer << spirv::VerCapExtAttr::getKindName() << "<"
1144           << spirv::stringifyVersion(triple.getVersion()) << ", [";
1145   llvm::interleaveComma(
1146       triple.getCapabilities(), os,
1147       [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); });
1148   printer << "], [";
1149   llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) {
1150     os << attr.cast<StringAttr>().getValue();
1151   });
1152   printer << "]>";
1153 }
1154 
1155 static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
1156   printer << spirv::TargetEnvAttr::getKindName() << "<#spv.";
1157   print(targetEnv.getTripleAttr(), printer);
1158   spirv::Vendor vendorID = targetEnv.getVendorID();
1159   spirv::DeviceType deviceType = targetEnv.getDeviceType();
1160   uint32_t deviceID = targetEnv.getDeviceID();
1161   if (vendorID != spirv::Vendor::Unknown) {
1162     printer << ", " << spirv::stringifyVendor(vendorID);
1163     if (deviceType != spirv::DeviceType::Unknown) {
1164       printer << ":" << spirv::stringifyDeviceType(deviceType);
1165       if (deviceID != spirv::TargetEnvAttr::kUnknownDeviceID)
1166         printer << ":" << deviceID;
1167     }
1168   }
1169   printer << ", " << targetEnv.getResourceLimits() << ">";
1170 }
1171 
1172 static void print(spirv::InterfaceVarABIAttr interfaceVarABIAttr,
1173                   DialectAsmPrinter &printer) {
1174   printer << spirv::InterfaceVarABIAttr::getKindName() << "<("
1175           << interfaceVarABIAttr.getDescriptorSet() << ", "
1176           << interfaceVarABIAttr.getBinding() << ")";
1177   auto storageClass = interfaceVarABIAttr.getStorageClass();
1178   if (storageClass)
1179     printer << ", " << spirv::stringifyStorageClass(*storageClass);
1180   printer << ">";
1181 }
1182 
1183 void SPIRVDialect::printAttribute(Attribute attr,
1184                                   DialectAsmPrinter &printer) const {
1185   if (auto targetEnv = attr.dyn_cast<TargetEnvAttr>())
1186     print(targetEnv, printer);
1187   else if (auto vceAttr = attr.dyn_cast<VerCapExtAttr>())
1188     print(vceAttr, printer);
1189   else if (auto interfaceVarABIAttr = attr.dyn_cast<InterfaceVarABIAttr>())
1190     print(interfaceVarABIAttr, printer);
1191   else
1192     llvm_unreachable("unhandled SPIR-V attribute kind");
1193 }
1194 
1195 //===----------------------------------------------------------------------===//
1196 // Constant
1197 //===----------------------------------------------------------------------===//
1198 
1199 Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
1200                                              Attribute value, Type type,
1201                                              Location loc) {
1202   if (!spirv::ConstantOp::isBuildableWith(type))
1203     return nullptr;
1204 
1205   return builder.create<spirv::ConstantOp>(loc, type, value);
1206 }
1207 
1208 //===----------------------------------------------------------------------===//
1209 // Shader Interface ABI
1210 //===----------------------------------------------------------------------===//
1211 
1212 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
1213                                                      NamedAttribute attribute) {
1214   StringRef symbol = attribute.getName().strref();
1215   Attribute attr = attribute.getValue();
1216 
1217   // TODO: figure out a way to generate the description from the
1218   // StructAttr definition.
1219   if (symbol == spirv::getEntryPointABIAttrName()) {
1220     if (!attr.isa<spirv::EntryPointABIAttr>())
1221       return op->emitError("'")
1222              << symbol
1223              << "' attribute must be a dictionary attribute containing one "
1224                 "32-bit integer elements attribute: 'local_size'";
1225   } else if (symbol == spirv::getTargetEnvAttrName()) {
1226     if (!attr.isa<spirv::TargetEnvAttr>())
1227       return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
1228   } else {
1229     return op->emitError("found unsupported '")
1230            << symbol << "' attribute on operation";
1231   }
1232 
1233   return success();
1234 }
1235 
1236 /// Verifies the given SPIR-V `attribute` attached to a value of the given
1237 /// `valueType` is valid.
1238 static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
1239                                            NamedAttribute attribute) {
1240   StringRef symbol = attribute.getName().strref();
1241   Attribute attr = attribute.getValue();
1242 
1243   if (symbol != spirv::getInterfaceVarABIAttrName())
1244     return emitError(loc, "found unsupported '")
1245            << symbol << "' attribute on region argument";
1246 
1247   auto varABIAttr = attr.dyn_cast<spirv::InterfaceVarABIAttr>();
1248   if (!varABIAttr)
1249     return emitError(loc, "'")
1250            << symbol << "' must be a spirv::InterfaceVarABIAttr";
1251 
1252   if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
1253     return emitError(loc, "'") << symbol
1254                                << "' attribute cannot specify storage class "
1255                                   "when attaching to a non-scalar value";
1256 
1257   return success();
1258 }
1259 
1260 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
1261                                                      unsigned regionIndex,
1262                                                      unsigned argIndex,
1263                                                      NamedAttribute attribute) {
1264   return verifyRegionAttribute(
1265       op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
1266       attribute);
1267 }
1268 
1269 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
1270     Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
1271     NamedAttribute attribute) {
1272   return op->emitError("cannot attach SPIR-V attributes to region result");
1273 }
1274