//===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the SPIR-V dialect in MLIR.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
17 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
18 #include "mlir/IR/Builders.h"
19 #include "mlir/IR/BuiltinTypes.h"
20 #include "mlir/IR/DialectImplementation.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/Parser/Parser.h"
23 #include "mlir/Transforms/InliningUtils.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringMap.h"
29 #include "llvm/ADT/StringSwitch.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/raw_ostream.h"
32 
33 using namespace mlir;
34 using namespace mlir::spirv;
35 
36 #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc"
37 
38 //===----------------------------------------------------------------------===//
39 // InlinerInterface
40 //===----------------------------------------------------------------------===//
41 
42 /// Returns true if the given region contains spv.Return or spv.ReturnValue ops.
containsReturn(Region & region)43 static inline bool containsReturn(Region &region) {
44   return llvm::any_of(region, [](Block &block) {
45     Operation *terminator = block.getTerminator();
46     return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator);
47   });
48 }
49 
50 namespace {
51 /// This class defines the interface for inlining within the SPIR-V dialect.
52 struct SPIRVInlinerInterface : public DialectInlinerInterface {
53   using DialectInlinerInterface::DialectInlinerInterface;
54 
55   /// All call operations within SPIRV can be inlined.
isLegalToInline__anon52aeee350211::SPIRVInlinerInterface56   bool isLegalToInline(Operation *call, Operation *callable,
57                        bool wouldBeCloned) const final {
58     return true;
59   }
60 
61   /// Returns true if the given region 'src' can be inlined into the region
62   /// 'dest' that is attached to an operation registered to the current dialect.
isLegalToInline__anon52aeee350211::SPIRVInlinerInterface63   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
64                        BlockAndValueMapping &) const final {
65     // Return true here when inlining into spv.func, spv.mlir.selection, and
66     // spv.mlir.loop operations.
67     auto *op = dest->getParentOp();
68     return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op);
69   }
70 
71   /// Returns true if the given operation 'op', that is registered to this
72   /// dialect, can be inlined into the region 'dest' that is attached to an
73   /// operation registered to the current dialect.
isLegalToInline__anon52aeee350211::SPIRVInlinerInterface74   bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
75                        BlockAndValueMapping &) const final {
76     // TODO: Enable inlining structured control flows with return.
77     if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) &&
78         containsReturn(op->getRegion(0)))
79       return false;
80     // TODO: we need to filter OpKill here to avoid inlining it to
81     // a loop continue construct:
82     // https://github.com/KhronosGroup/SPIRV-Headers/issues/86
83     // However OpKill is fragment shader specific and we don't support it yet.
84     return true;
85   }
86 
87   /// Handle the given inlined terminator by replacing it with a new operation
88   /// as necessary.
handleTerminator__anon52aeee350211::SPIRVInlinerInterface89   void handleTerminator(Operation *op, Block *newDest) const final {
90     if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) {
91       OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest);
92       op->erase();
93     } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) {
94       llvm_unreachable("unimplemented spv.ReturnValue in inliner");
95     }
96   }
97 
98   /// Handle the given inlined terminator by replacing it with a new operation
99   /// as necessary.
handleTerminator__anon52aeee350211::SPIRVInlinerInterface100   void handleTerminator(Operation *op,
101                         ArrayRef<Value> valuesToRepl) const final {
102     // Only spv.ReturnValue needs to be handled here.
103     auto retValOp = dyn_cast<spirv::ReturnValueOp>(op);
104     if (!retValOp)
105       return;
106 
107     // Replace the values directly with the return operands.
108     assert(valuesToRepl.size() == 1 &&
109            "spv.ReturnValue expected to only handle one result");
110     valuesToRepl.front().replaceAllUsesWith(retValOp.value());
111   }
112 };
113 } // namespace
114 
115 //===----------------------------------------------------------------------===//
116 // SPIR-V Dialect
117 //===----------------------------------------------------------------------===//
118 
/// Sets up the dialect: registers its attributes and types, adds the
/// operations listed in the generated SPIRVOps.cpp.inc, attaches the inliner
/// interface, and allows unknown operations.
void SPIRVDialect::initialize() {
  registerAttributes();
  registerTypes();

  // Add SPIR-V ops. GET_OP_LIST makes the generated include expand to the
  // list of op classes used as the template arguments here.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
      >();

  // Hook up the inlining legality/terminator handling defined in this file.
  addInterfaces<SPIRVInlinerInterface>();

  // Allow unknown operations because SPIR-V is extensible.
  allowUnknownOperations();
}
134 
getAttributeName(Decoration decoration)135 std::string SPIRVDialect::getAttributeName(Decoration decoration) {
136   return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration));
137 }
138 
139 //===----------------------------------------------------------------------===//
140 // Type Parsing
141 //===----------------------------------------------------------------------===//
142 
143 // Forward declarations.
144 template <typename ValTy>
145 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
146                                       DialectAsmParser &parser);
147 template <>
148 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
149                                     DialectAsmParser &parser);
150 
151 template <>
152 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
153                                             DialectAsmParser &parser);
154 
parseAndVerifyType(SPIRVDialect const & dialect,DialectAsmParser & parser)155 static Type parseAndVerifyType(SPIRVDialect const &dialect,
156                                DialectAsmParser &parser) {
157   Type type;
158   SMLoc typeLoc = parser.getCurrentLocation();
159   if (parser.parseType(type))
160     return Type();
161 
162   // Allow SPIR-V dialect types
163   if (&type.getDialect() == &dialect)
164     return type;
165 
166   // Check other allowed types
167   if (auto t = type.dyn_cast<FloatType>()) {
168     if (type.isBF16()) {
169       parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types");
170       return Type();
171     }
172   } else if (auto t = type.dyn_cast<IntegerType>()) {
173     if (!ScalarType::isValid(t)) {
174       parser.emitError(typeLoc,
175                        "only 1/8/16/32/64-bit integer type allowed but found ")
176           << type;
177       return Type();
178     }
179   } else if (auto t = type.dyn_cast<VectorType>()) {
180     if (t.getRank() != 1) {
181       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
182       return Type();
183     }
184     if (t.getNumElements() > 4) {
185       parser.emitError(
186           typeLoc, "vector length has to be less than or equal to 4 but found ")
187           << t.getNumElements();
188       return Type();
189     }
190   } else {
191     parser.emitError(typeLoc, "cannot use ")
192         << type << " to compose SPIR-V types";
193     return Type();
194   }
195 
196   return type;
197 }
198 
parseAndVerifyMatrixType(SPIRVDialect const & dialect,DialectAsmParser & parser)199 static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect,
200                                      DialectAsmParser &parser) {
201   Type type;
202   SMLoc typeLoc = parser.getCurrentLocation();
203   if (parser.parseType(type))
204     return Type();
205 
206   if (auto t = type.dyn_cast<VectorType>()) {
207     if (t.getRank() != 1) {
208       parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t;
209       return Type();
210     }
211     if (t.getNumElements() > 4 || t.getNumElements() < 2) {
212       parser.emitError(typeLoc,
213                        "matrix columns size has to be less than or equal "
214                        "to 4 and greater than or equal 2, but found ")
215           << t.getNumElements();
216       return Type();
217     }
218 
219     if (!t.getElementType().isa<FloatType>()) {
220       parser.emitError(typeLoc, "matrix columns' elements must be of "
221                                 "Float type, got ")
222           << t.getElementType();
223       return Type();
224     }
225   } else {
226     parser.emitError(typeLoc, "matrix must be composed using vector "
227                               "type, got ")
228         << type;
229     return Type();
230   }
231 
232   return type;
233 }
234 
parseAndVerifySampledImageType(SPIRVDialect const & dialect,DialectAsmParser & parser)235 static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect,
236                                            DialectAsmParser &parser) {
237   Type type;
238   SMLoc typeLoc = parser.getCurrentLocation();
239   if (parser.parseType(type))
240     return Type();
241 
242   if (!type.isa<ImageType>()) {
243     parser.emitError(typeLoc,
244                      "sampled image must be composed using image type, got ")
245         << type;
246     return Type();
247   }
248 
249   return type;
250 }
251 
252 /// Parses an optional `, stride = N` assembly segment. If no parsing failure
253 /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if
254 /// missing.
parseOptionalArrayStride(const SPIRVDialect & dialect,DialectAsmParser & parser,unsigned & stride)255 static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect,
256                                               DialectAsmParser &parser,
257                                               unsigned &stride) {
258   if (failed(parser.parseOptionalComma())) {
259     stride = 0;
260     return success();
261   }
262 
263   if (parser.parseKeyword("stride") || parser.parseEqual())
264     return failure();
265 
266   SMLoc strideLoc = parser.getCurrentLocation();
267   Optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser);
268   if (!optStride)
269     return failure();
270 
271   if (!(stride = *optStride)) {
272     parser.emitError(strideLoc, "ArrayStride must be greater than zero");
273     return failure();
274   }
275   return success();
276 }
277 
278 // element-type ::= integer-type
279 //                | floating-point-type
280 //                | vector-type
281 //                | spirv-type
282 //
283 // array-type ::= `!spv.array` `<` integer-literal `x` element-type
284 //                (`,` `stride` `=` integer-literal)? `>`
parseArrayType(SPIRVDialect const & dialect,DialectAsmParser & parser)285 static Type parseArrayType(SPIRVDialect const &dialect,
286                            DialectAsmParser &parser) {
287   if (parser.parseLess())
288     return Type();
289 
290   SmallVector<int64_t, 1> countDims;
291   SMLoc countLoc = parser.getCurrentLocation();
292   if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
293     return Type();
294   if (countDims.size() != 1) {
295     parser.emitError(countLoc,
296                      "expected single integer for array element count");
297     return Type();
298   }
299 
300   // According to the SPIR-V spec:
301   // "Length is the number of elements in the array. It must be at least 1."
302   int64_t count = countDims[0];
303   if (count == 0) {
304     parser.emitError(countLoc, "expected array length greater than 0");
305     return Type();
306   }
307 
308   Type elementType = parseAndVerifyType(dialect, parser);
309   if (!elementType)
310     return Type();
311 
312   unsigned stride = 0;
313   if (failed(parseOptionalArrayStride(dialect, parser, stride)))
314     return Type();
315 
316   if (parser.parseGreater())
317     return Type();
318   return ArrayType::get(elementType, count, stride);
319 }
320 
321 // cooperative-matrix-type ::= `!spv.coopmatrix` `<` element-type ',' scope ','
322 //                                                   rows ',' columns>`
parseCooperativeMatrixType(SPIRVDialect const & dialect,DialectAsmParser & parser)323 static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
324                                        DialectAsmParser &parser) {
325   if (parser.parseLess())
326     return Type();
327 
328   SmallVector<int64_t, 2> dims;
329   SMLoc countLoc = parser.getCurrentLocation();
330   if (parser.parseDimensionList(dims, /*allowDynamic=*/false))
331     return Type();
332 
333   if (dims.size() != 2) {
334     parser.emitError(countLoc, "expected rows and columns size");
335     return Type();
336   }
337 
338   auto elementTy = parseAndVerifyType(dialect, parser);
339   if (!elementTy)
340     return Type();
341 
342   Scope scope;
343   if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>"))
344     return Type();
345 
346   if (parser.parseGreater())
347     return Type();
348   return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]);
349 }
350 
351 // TODO: Reorder methods to be utilities first and parse*Type
352 // methods in alphabetical order
353 //
354 // storage-class ::= `UniformConstant`
355 //                 | `Uniform`
356 //                 | `Workgroup`
357 //                 | <and other storage classes...>
358 //
359 // pointer-type ::= `!spv.ptr<` element-type `,` storage-class `>`
parsePointerType(SPIRVDialect const & dialect,DialectAsmParser & parser)360 static Type parsePointerType(SPIRVDialect const &dialect,
361                              DialectAsmParser &parser) {
362   if (parser.parseLess())
363     return Type();
364 
365   auto pointeeType = parseAndVerifyType(dialect, parser);
366   if (!pointeeType)
367     return Type();
368 
369   StringRef storageClassSpec;
370   SMLoc storageClassLoc = parser.getCurrentLocation();
371   if (parser.parseComma() || parser.parseKeyword(&storageClassSpec))
372     return Type();
373 
374   auto storageClass = symbolizeStorageClass(storageClassSpec);
375   if (!storageClass) {
376     parser.emitError(storageClassLoc, "unknown storage class: ")
377         << storageClassSpec;
378     return Type();
379   }
380   if (parser.parseGreater())
381     return Type();
382   return PointerType::get(pointeeType, *storageClass);
383 }
384 
385 // runtime-array-type ::= `!spv.rtarray` `<` element-type
386 //                        (`,` `stride` `=` integer-literal)? `>`
parseRuntimeArrayType(SPIRVDialect const & dialect,DialectAsmParser & parser)387 static Type parseRuntimeArrayType(SPIRVDialect const &dialect,
388                                   DialectAsmParser &parser) {
389   if (parser.parseLess())
390     return Type();
391 
392   Type elementType = parseAndVerifyType(dialect, parser);
393   if (!elementType)
394     return Type();
395 
396   unsigned stride = 0;
397   if (failed(parseOptionalArrayStride(dialect, parser, stride)))
398     return Type();
399 
400   if (parser.parseGreater())
401     return Type();
402   return RuntimeArrayType::get(elementType, stride);
403 }
404 
405 // matrix-type ::= `!spv.matrix` `<` integer-literal `x` element-type `>`
parseMatrixType(SPIRVDialect const & dialect,DialectAsmParser & parser)406 static Type parseMatrixType(SPIRVDialect const &dialect,
407                             DialectAsmParser &parser) {
408   if (parser.parseLess())
409     return Type();
410 
411   SmallVector<int64_t, 1> countDims;
412   SMLoc countLoc = parser.getCurrentLocation();
413   if (parser.parseDimensionList(countDims, /*allowDynamic=*/false))
414     return Type();
415   if (countDims.size() != 1) {
416     parser.emitError(countLoc, "expected single unsigned "
417                                "integer for number of columns");
418     return Type();
419   }
420 
421   int64_t columnCount = countDims[0];
422   // According to the specification, Matrices can have 2, 3, or 4 columns
423   if (columnCount < 2 || columnCount > 4) {
424     parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 "
425                                "columns");
426     return Type();
427   }
428 
429   Type columnType = parseAndVerifyMatrixType(dialect, parser);
430   if (!columnType)
431     return Type();
432 
433   if (parser.parseGreater())
434     return Type();
435 
436   return MatrixType::get(columnType, columnCount);
437 }
438 
439 // Specialize this function to parse each of the parameters that define an
440 // ImageType. By default it assumes this is an enum type.
441 template <typename ValTy>
parseAndVerify(SPIRVDialect const & dialect,DialectAsmParser & parser)442 static Optional<ValTy> parseAndVerify(SPIRVDialect const &dialect,
443                                       DialectAsmParser &parser) {
444   StringRef enumSpec;
445   SMLoc enumLoc = parser.getCurrentLocation();
446   if (parser.parseKeyword(&enumSpec)) {
447     return llvm::None;
448   }
449 
450   auto val = spirv::symbolizeEnum<ValTy>(enumSpec);
451   if (!val)
452     parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'";
453   return val;
454 }
455 
456 template <>
parseAndVerify(SPIRVDialect const & dialect,DialectAsmParser & parser)457 Optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect,
458                                     DialectAsmParser &parser) {
459   // TODO: Further verify that the element type can be sampled
460   auto ty = parseAndVerifyType(dialect, parser);
461   if (!ty)
462     return llvm::None;
463   return ty;
464 }
465 
466 template <typename IntTy>
parseAndVerifyInteger(SPIRVDialect const & dialect,DialectAsmParser & parser)467 static Optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect,
468                                              DialectAsmParser &parser) {
469   IntTy offsetVal = std::numeric_limits<IntTy>::max();
470   if (parser.parseInteger(offsetVal))
471     return llvm::None;
472   return offsetVal;
473 }
474 
475 template <>
parseAndVerify(SPIRVDialect const & dialect,DialectAsmParser & parser)476 Optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect,
477                                             DialectAsmParser &parser) {
478   return parseAndVerifyInteger<unsigned>(dialect, parser);
479 }
480 
481 namespace {
482 // Functor object to parse a comma separated list of specs. The function
483 // parseAndVerify does the actual parsing and verification of individual
484 // elements. This is a functor since parsing the last element of the list
485 // (termination condition) needs partial specialization.
486 template <typename ParseType, typename... Args>
487 struct ParseCommaSeparatedList {
488   Optional<std::tuple<ParseType, Args...>>
operator ()__anon52aeee350311::ParseCommaSeparatedList489   operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const {
490     auto parseVal = parseAndVerify<ParseType>(dialect, parser);
491     if (!parseVal)
492       return llvm::None;
493 
494     auto numArgs = std::tuple_size<std::tuple<Args...>>::value;
495     if (numArgs != 0 && failed(parser.parseComma()))
496       return llvm::None;
497     auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser);
498     if (!remainingValues)
499       return llvm::None;
500     return std::tuple_cat(std::tuple<ParseType>(parseVal.value()),
501                           remainingValues.value());
502   }
503 };
504 
505 // Partial specialization of the function to parse a comma separated list of
506 // specs to parse the last element of the list.
507 template <typename ParseType>
508 struct ParseCommaSeparatedList<ParseType> {
operator ()__anon52aeee350311::ParseCommaSeparatedList509   Optional<std::tuple<ParseType>> operator()(SPIRVDialect const &dialect,
510                                              DialectAsmParser &parser) const {
511     if (auto value = parseAndVerify<ParseType>(dialect, parser))
512       return std::tuple<ParseType>(*value);
513     return llvm::None;
514   }
515 };
516 } // namespace
517 
518 // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...>
519 //
520 // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown`
521 //
522 // arrayed-info ::= `NonArrayed` | `Arrayed`
523 //
524 // sampling-info ::= `SingleSampled` | `MultiSampled`
525 //
526 // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` |  `NoSampler`
527 //
528 // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...>
529 //
530 // image-type ::= `!spv.image<` element-type `,` dim `,` depth-info `,`
531 //                              arrayed-info `,` sampling-info `,`
532 //                              sampler-use-info `,` format `>`
parseImageType(SPIRVDialect const & dialect,DialectAsmParser & parser)533 static Type parseImageType(SPIRVDialect const &dialect,
534                            DialectAsmParser &parser) {
535   if (parser.parseLess())
536     return Type();
537 
538   auto value =
539       ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo,
540                               ImageSamplingInfo, ImageSamplerUseInfo,
541                               ImageFormat>{}(dialect, parser);
542   if (!value)
543     return Type();
544 
545   if (parser.parseGreater())
546     return Type();
547   return ImageType::get(*value);
548 }
549 
550 // sampledImage-type :: = `!spv.sampledImage<` image-type `>`
parseSampledImageType(SPIRVDialect const & dialect,DialectAsmParser & parser)551 static Type parseSampledImageType(SPIRVDialect const &dialect,
552                                   DialectAsmParser &parser) {
553   if (parser.parseLess())
554     return Type();
555 
556   Type parsedType = parseAndVerifySampledImageType(dialect, parser);
557   if (!parsedType)
558     return Type();
559 
560   if (parser.parseGreater())
561     return Type();
562   return SampledImageType::get(parsedType);
563 }
564 
565 // Parse decorations associated with a member.
parseStructMemberDecorations(SPIRVDialect const & dialect,DialectAsmParser & parser,ArrayRef<Type> memberTypes,SmallVectorImpl<StructType::OffsetInfo> & offsetInfo,SmallVectorImpl<StructType::MemberDecorationInfo> & memberDecorationInfo)566 static ParseResult parseStructMemberDecorations(
567     SPIRVDialect const &dialect, DialectAsmParser &parser,
568     ArrayRef<Type> memberTypes,
569     SmallVectorImpl<StructType::OffsetInfo> &offsetInfo,
570     SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) {
571 
572   // Check if the first element is offset.
573   SMLoc offsetLoc = parser.getCurrentLocation();
574   StructType::OffsetInfo offset = 0;
575   OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset);
576   if (offsetParseResult.hasValue()) {
577     if (failed(*offsetParseResult))
578       return failure();
579 
580     if (offsetInfo.size() != memberTypes.size() - 1) {
581       return parser.emitError(offsetLoc,
582                               "offset specification must be given for "
583                               "all members");
584     }
585     offsetInfo.push_back(offset);
586   }
587 
588   // Check for no spirv::Decorations.
589   if (succeeded(parser.parseOptionalRSquare()))
590     return success();
591 
592   // If there was an offset, make sure to parse the comma.
593   if (offsetParseResult.hasValue() && parser.parseComma())
594     return failure();
595 
596   // Check for spirv::Decorations.
597   auto parseDecorations = [&]() {
598     auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser);
599     if (!memberDecoration)
600       return failure();
601 
602     // Parse member decoration value if it exists.
603     if (succeeded(parser.parseOptionalEqual())) {
604       auto memberDecorationValue =
605           parseAndVerifyInteger<uint32_t>(dialect, parser);
606 
607       if (!memberDecorationValue)
608         return failure();
609 
610       memberDecorationInfo.emplace_back(
611           static_cast<uint32_t>(memberTypes.size() - 1), 1,
612           memberDecoration.value(), memberDecorationValue.value());
613     } else {
614       memberDecorationInfo.emplace_back(
615           static_cast<uint32_t>(memberTypes.size() - 1), 0,
616           memberDecoration.value(), 0);
617     }
618     return success();
619   };
620   if (failed(parser.parseCommaSeparatedList(parseDecorations)) ||
621       failed(parser.parseRSquare()))
622     return failure();
623 
624   return success();
625 }
626 
627 // struct-member-decoration ::= integer-literal? spirv-decoration*
628 // struct-type ::=
629 //             `!spv.struct<` (id `,`)?
630 //                          `(`
631 //                            (spirv-type (`[` struct-member-decoration `]`)?)*
632 //                          `)>`
parseStructType(SPIRVDialect const & dialect,DialectAsmParser & parser)633 static Type parseStructType(SPIRVDialect const &dialect,
634                             DialectAsmParser &parser) {
635   // TODO: This function is quite lengthy. Break it down into smaller chunks.
636 
637   // To properly resolve recursive references while parsing recursive struct
638   // types, we need to maintain a list of enclosing struct type names. This set
639   // maintains the names of struct types in which the type we are about to parse
640   // is nested.
641   //
642   // Note: This has to be thread_local to enable multiple threads to safely
643   // parse concurrently.
644   thread_local SetVector<StringRef> structContext;
645 
646   static auto removeIdentifierAndFail = [](SetVector<StringRef> &structContext,
647                                            StringRef identifier) {
648     if (!identifier.empty())
649       structContext.remove(identifier);
650 
651     return Type();
652   };
653 
654   if (parser.parseLess())
655     return Type();
656 
657   StringRef identifier;
658 
659   // Check if this is an identified struct type.
660   if (succeeded(parser.parseOptionalKeyword(&identifier))) {
661     // Check if this is a possible recursive reference.
662     if (succeeded(parser.parseOptionalGreater())) {
663       if (structContext.count(identifier) == 0) {
664         parser.emitError(
665             parser.getNameLoc(),
666             "recursive struct reference not nested in struct definition");
667 
668         return Type();
669       }
670 
671       return StructType::getIdentified(dialect.getContext(), identifier);
672     }
673 
674     if (failed(parser.parseComma()))
675       return Type();
676 
677     if (structContext.count(identifier) != 0) {
678       parser.emitError(parser.getNameLoc(),
679                        "identifier already used for an enclosing struct");
680 
681       return removeIdentifierAndFail(structContext, identifier);
682     }
683 
684     structContext.insert(identifier);
685   }
686 
687   if (failed(parser.parseLParen()))
688     return removeIdentifierAndFail(structContext, identifier);
689 
690   if (succeeded(parser.parseOptionalRParen()) &&
691       succeeded(parser.parseOptionalGreater())) {
692     if (!identifier.empty())
693       structContext.remove(identifier);
694 
695     return StructType::getEmpty(dialect.getContext(), identifier);
696   }
697 
698   StructType idStructTy;
699 
700   if (!identifier.empty())
701     idStructTy = StructType::getIdentified(dialect.getContext(), identifier);
702 
703   SmallVector<Type, 4> memberTypes;
704   SmallVector<StructType::OffsetInfo, 4> offsetInfo;
705   SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo;
706 
707   do {
708     Type memberType;
709     if (parser.parseType(memberType))
710       return removeIdentifierAndFail(structContext, identifier);
711     memberTypes.push_back(memberType);
712 
713     if (succeeded(parser.parseOptionalLSquare()))
714       if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo,
715                                        memberDecorationInfo))
716         return removeIdentifierAndFail(structContext, identifier);
717   } while (succeeded(parser.parseOptionalComma()));
718 
719   if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) {
720     parser.emitError(parser.getNameLoc(),
721                      "offset specification must be given for all members");
722     return removeIdentifierAndFail(structContext, identifier);
723   }
724 
725   if (failed(parser.parseRParen()) || failed(parser.parseGreater()))
726     return removeIdentifierAndFail(structContext, identifier);
727 
728   if (!identifier.empty()) {
729     if (failed(idStructTy.trySetBody(memberTypes, offsetInfo,
730                                      memberDecorationInfo)))
731       return Type();
732 
733     structContext.remove(identifier);
734     return idStructTy;
735   }
736 
737   return StructType::get(memberTypes, offsetInfo, memberDecorationInfo);
738 }
739 
740 // spirv-type ::= array-type
741 //              | element-type
742 //              | image-type
743 //              | pointer-type
744 //              | runtime-array-type
745 //              | sampled-image-type
746 //              | struct-type
parseType(DialectAsmParser & parser) const747 Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
748   StringRef keyword;
749   if (parser.parseKeyword(&keyword))
750     return Type();
751 
752   if (keyword == "array")
753     return parseArrayType(*this, parser);
754   if (keyword == "coopmatrix")
755     return parseCooperativeMatrixType(*this, parser);
756   if (keyword == "image")
757     return parseImageType(*this, parser);
758   if (keyword == "ptr")
759     return parsePointerType(*this, parser);
760   if (keyword == "rtarray")
761     return parseRuntimeArrayType(*this, parser);
762   if (keyword == "sampled_image")
763     return parseSampledImageType(*this, parser);
764   if (keyword == "struct")
765     return parseStructType(*this, parser);
766   if (keyword == "matrix")
767     return parseMatrixType(*this, parser);
768   parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
769   return Type();
770 }
771 
772 //===----------------------------------------------------------------------===//
773 // Type Printing
774 //===----------------------------------------------------------------------===//
775 
print(ArrayType type,DialectAsmPrinter & os)776 static void print(ArrayType type, DialectAsmPrinter &os) {
777   os << "array<" << type.getNumElements() << " x " << type.getElementType();
778   if (unsigned stride = type.getArrayStride())
779     os << ", stride=" << stride;
780   os << ">";
781 }
782 
print(RuntimeArrayType type,DialectAsmPrinter & os)783 static void print(RuntimeArrayType type, DialectAsmPrinter &os) {
784   os << "rtarray<" << type.getElementType();
785   if (unsigned stride = type.getArrayStride())
786     os << ", stride=" << stride;
787   os << ">";
788 }
789 
print(PointerType type,DialectAsmPrinter & os)790 static void print(PointerType type, DialectAsmPrinter &os) {
791   os << "ptr<" << type.getPointeeType() << ", "
792      << stringifyStorageClass(type.getStorageClass()) << ">";
793 }
794 
print(ImageType type,DialectAsmPrinter & os)795 static void print(ImageType type, DialectAsmPrinter &os) {
796   os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim())
797      << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", "
798      << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", "
799      << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", "
800      << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", "
801      << stringifyImageFormat(type.getImageFormat()) << ">";
802 }
803 
print(SampledImageType type,DialectAsmPrinter & os)804 static void print(SampledImageType type, DialectAsmPrinter &os) {
805   os << "sampled_image<" << type.getImageType() << ">";
806 }
807 
print(StructType type,DialectAsmPrinter & os)808 static void print(StructType type, DialectAsmPrinter &os) {
809   thread_local SetVector<StringRef> structContext;
810 
811   os << "struct<";
812 
813   if (type.isIdentified()) {
814     os << type.getIdentifier();
815 
816     if (structContext.count(type.getIdentifier())) {
817       os << ">";
818       return;
819     }
820 
821     os << ", ";
822     structContext.insert(type.getIdentifier());
823   }
824 
825   os << "(";
826 
827   auto printMember = [&](unsigned i) {
828     os << type.getElementType(i);
829     SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations;
830     type.getMemberDecorations(i, decorations);
831     if (type.hasOffset() || !decorations.empty()) {
832       os << " [";
833       if (type.hasOffset()) {
834         os << type.getMemberOffset(i);
835         if (!decorations.empty())
836           os << ", ";
837       }
838       auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) {
839         os << stringifyDecoration(decoration.decoration);
840         if (decoration.hasValue) {
841           os << "=" << decoration.decorationValue;
842         }
843       };
844       llvm::interleaveComma(decorations, os, eachFn);
845       os << "]";
846     }
847   };
848   llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os,
849                         printMember);
850   os << ")>";
851 
852   if (type.isIdentified())
853     structContext.remove(type.getIdentifier());
854 }
855 
print(CooperativeMatrixNVType type,DialectAsmPrinter & os)856 static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) {
857   os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x";
858   os << type.getElementType() << ", " << stringifyScope(type.getScope());
859   os << ">";
860 }
861 
print(MatrixType type,DialectAsmPrinter & os)862 static void print(MatrixType type, DialectAsmPrinter &os) {
863   os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType();
864   os << ">";
865 }
866 
printType(Type type,DialectAsmPrinter & os) const867 void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
868   TypeSwitch<Type>(type)
869       .Case<ArrayType, CooperativeMatrixNVType, PointerType, RuntimeArrayType,
870             ImageType, SampledImageType, StructType, MatrixType>(
871           [&](auto type) { print(type, os); })
872       .Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
873 }
874 
875 //===----------------------------------------------------------------------===//
876 // Constant
877 //===----------------------------------------------------------------------===//
878 
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)879 Operation *SPIRVDialect::materializeConstant(OpBuilder &builder,
880                                              Attribute value, Type type,
881                                              Location loc) {
882   if (!spirv::ConstantOp::isBuildableWith(type))
883     return nullptr;
884 
885   return builder.create<spirv::ConstantOp>(loc, type, value);
886 }
887 
888 //===----------------------------------------------------------------------===//
889 // Shader Interface ABI
890 //===----------------------------------------------------------------------===//
891 
verifyOperationAttribute(Operation * op,NamedAttribute attribute)892 LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op,
893                                                      NamedAttribute attribute) {
894   StringRef symbol = attribute.getName().strref();
895   Attribute attr = attribute.getValue();
896 
897   if (symbol == spirv::getEntryPointABIAttrName()) {
898     if (!attr.isa<spirv::EntryPointABIAttr>()) {
899       return op->emitError("'")
900              << symbol << "' attribute must be an entry point ABI attribute";
901     }
902   } else if (symbol == spirv::getTargetEnvAttrName()) {
903     if (!attr.isa<spirv::TargetEnvAttr>())
904       return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr";
905   } else {
906     return op->emitError("found unsupported '")
907            << symbol << "' attribute on operation";
908   }
909 
910   return success();
911 }
912 
913 /// Verifies the given SPIR-V `attribute` attached to a value of the given
914 /// `valueType` is valid.
verifyRegionAttribute(Location loc,Type valueType,NamedAttribute attribute)915 static LogicalResult verifyRegionAttribute(Location loc, Type valueType,
916                                            NamedAttribute attribute) {
917   StringRef symbol = attribute.getName().strref();
918   Attribute attr = attribute.getValue();
919 
920   if (symbol != spirv::getInterfaceVarABIAttrName())
921     return emitError(loc, "found unsupported '")
922            << symbol << "' attribute on region argument";
923 
924   auto varABIAttr = attr.dyn_cast<spirv::InterfaceVarABIAttr>();
925   if (!varABIAttr)
926     return emitError(loc, "'")
927            << symbol << "' must be a spirv::InterfaceVarABIAttr";
928 
929   if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat())
930     return emitError(loc, "'") << symbol
931                                << "' attribute cannot specify storage class "
932                                   "when attaching to a non-scalar value";
933 
934   return success();
935 }
936 
verifyRegionArgAttribute(Operation * op,unsigned regionIndex,unsigned argIndex,NamedAttribute attribute)937 LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op,
938                                                      unsigned regionIndex,
939                                                      unsigned argIndex,
940                                                      NamedAttribute attribute) {
941   return verifyRegionAttribute(
942       op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(),
943       attribute);
944 }
945 
verifyRegionResultAttribute(Operation * op,unsigned,unsigned,NamedAttribute attribute)946 LogicalResult SPIRVDialect::verifyRegionResultAttribute(
947     Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/,
948     NamedAttribute attribute) {
949   return op->emitError("cannot attach SPIR-V attributes to region result");
950 }
951