1 //===- SPIRVAttributes.cpp - SPIR-V attribute definitions -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
10 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
11 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "llvm/ADT/TypeSwitch.h"
15 
16 using namespace mlir;
17 using namespace mlir::spirv;
18 
19 //===----------------------------------------------------------------------===//
20 // TableGen'erated attribute utility functions
21 //===----------------------------------------------------------------------===//
22 
23 namespace mlir {
24 namespace spirv {
25 #include "mlir/Dialect/SPIRV/IR/SPIRVAttrUtils.inc"
26 } // namespace spirv
27 
28 //===----------------------------------------------------------------------===//
29 // Attribute storage classes
30 //===----------------------------------------------------------------------===//
31 
32 namespace spirv {
33 namespace detail {
34 
35 struct InterfaceVarABIAttributeStorage : public AttributeStorage {
36   using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
37 
InterfaceVarABIAttributeStoragemlir::spirv::detail::InterfaceVarABIAttributeStorage38   InterfaceVarABIAttributeStorage(Attribute descriptorSet, Attribute binding,
39                                   Attribute storageClass)
40       : descriptorSet(descriptorSet), binding(binding),
41         storageClass(storageClass) {}
42 
operator ==mlir::spirv::detail::InterfaceVarABIAttributeStorage43   bool operator==(const KeyTy &key) const {
44     return std::get<0>(key) == descriptorSet && std::get<1>(key) == binding &&
45            std::get<2>(key) == storageClass;
46   }
47 
48   static InterfaceVarABIAttributeStorage *
constructmlir::spirv::detail::InterfaceVarABIAttributeStorage49   construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
50     return new (allocator.allocate<InterfaceVarABIAttributeStorage>())
51         InterfaceVarABIAttributeStorage(std::get<0>(key), std::get<1>(key),
52                                         std::get<2>(key));
53   }
54 
55   Attribute descriptorSet;
56   Attribute binding;
57   Attribute storageClass;
58 };
59 
60 struct VerCapExtAttributeStorage : public AttributeStorage {
61   using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
62 
VerCapExtAttributeStoragemlir::spirv::detail::VerCapExtAttributeStorage63   VerCapExtAttributeStorage(Attribute version, Attribute capabilities,
64                             Attribute extensions)
65       : version(version), capabilities(capabilities), extensions(extensions) {}
66 
operator ==mlir::spirv::detail::VerCapExtAttributeStorage67   bool operator==(const KeyTy &key) const {
68     return std::get<0>(key) == version && std::get<1>(key) == capabilities &&
69            std::get<2>(key) == extensions;
70   }
71 
72   static VerCapExtAttributeStorage *
constructmlir::spirv::detail::VerCapExtAttributeStorage73   construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
74     return new (allocator.allocate<VerCapExtAttributeStorage>())
75         VerCapExtAttributeStorage(std::get<0>(key), std::get<1>(key),
76                                   std::get<2>(key));
77   }
78 
79   Attribute version;
80   Attribute capabilities;
81   Attribute extensions;
82 };
83 
84 struct TargetEnvAttributeStorage : public AttributeStorage {
85   using KeyTy = std::tuple<Attribute, Vendor, DeviceType, uint32_t, Attribute>;
86 
TargetEnvAttributeStoragemlir::spirv::detail::TargetEnvAttributeStorage87   TargetEnvAttributeStorage(Attribute triple, Vendor vendorID,
88                             DeviceType deviceType, uint32_t deviceID,
89                             Attribute limits)
90       : triple(triple), limits(limits), vendorID(vendorID),
91         deviceType(deviceType), deviceID(deviceID) {}
92 
operator ==mlir::spirv::detail::TargetEnvAttributeStorage93   bool operator==(const KeyTy &key) const {
94     return key ==
95            std::make_tuple(triple, vendorID, deviceType, deviceID, limits);
96   }
97 
98   static TargetEnvAttributeStorage *
constructmlir::spirv::detail::TargetEnvAttributeStorage99   construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
100     return new (allocator.allocate<TargetEnvAttributeStorage>())
101         TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key),
102                                   std::get<2>(key), std::get<3>(key),
103                                   std::get<4>(key));
104   }
105 
106   Attribute triple;
107   Attribute limits;
108   Vendor vendorID;
109   DeviceType deviceType;
110   uint32_t deviceID;
111 };
112 } // namespace detail
113 } // namespace spirv
114 } // namespace mlir
115 
116 //===----------------------------------------------------------------------===//
117 // InterfaceVarABIAttr
118 //===----------------------------------------------------------------------===//
119 
120 spirv::InterfaceVarABIAttr
get(uint32_t descriptorSet,uint32_t binding,Optional<spirv::StorageClass> storageClass,MLIRContext * context)121 spirv::InterfaceVarABIAttr::get(uint32_t descriptorSet, uint32_t binding,
122                                 Optional<spirv::StorageClass> storageClass,
123                                 MLIRContext *context) {
124   Builder b(context);
125   auto descriptorSetAttr = b.getI32IntegerAttr(descriptorSet);
126   auto bindingAttr = b.getI32IntegerAttr(binding);
127   auto storageClassAttr =
128       storageClass ? b.getI32IntegerAttr(static_cast<uint32_t>(*storageClass))
129                    : IntegerAttr();
130   return get(descriptorSetAttr, bindingAttr, storageClassAttr);
131 }
132 
133 spirv::InterfaceVarABIAttr
get(IntegerAttr descriptorSet,IntegerAttr binding,IntegerAttr storageClass)134 spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding,
135                                 IntegerAttr storageClass) {
136   assert(descriptorSet && binding);
137   MLIRContext *context = descriptorSet.getContext();
138   return Base::get(context, descriptorSet, binding, storageClass);
139 }
140 
getKindName()141 StringRef spirv::InterfaceVarABIAttr::getKindName() {
142   return "interface_var_abi";
143 }
144 
getBinding()145 uint32_t spirv::InterfaceVarABIAttr::getBinding() {
146   return getImpl()->binding.cast<IntegerAttr>().getInt();
147 }
148 
getDescriptorSet()149 uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() {
150   return getImpl()->descriptorSet.cast<IntegerAttr>().getInt();
151 }
152 
getStorageClass()153 Optional<spirv::StorageClass> spirv::InterfaceVarABIAttr::getStorageClass() {
154   if (getImpl()->storageClass)
155     return static_cast<spirv::StorageClass>(
156         getImpl()->storageClass.cast<IntegerAttr>().getValue().getZExtValue());
157   return llvm::None;
158 }
159 
verify(function_ref<InFlightDiagnostic ()> emitError,IntegerAttr descriptorSet,IntegerAttr binding,IntegerAttr storageClass)160 LogicalResult spirv::InterfaceVarABIAttr::verify(
161     function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet,
162     IntegerAttr binding, IntegerAttr storageClass) {
163   if (!descriptorSet.getType().isSignlessInteger(32))
164     return emitError() << "expected 32-bit integer for descriptor set";
165 
166   if (!binding.getType().isSignlessInteger(32))
167     return emitError() << "expected 32-bit integer for binding";
168 
169   if (storageClass) {
170     if (auto storageClassAttr = storageClass.cast<IntegerAttr>()) {
171       auto storageClassValue =
172           spirv::symbolizeStorageClass(storageClassAttr.getInt());
173       if (!storageClassValue)
174         return emitError() << "unknown storage class";
175     } else {
176       return emitError() << "expected valid storage class";
177     }
178   }
179 
180   return success();
181 }
182 
183 //===----------------------------------------------------------------------===//
184 // VerCapExtAttr
185 //===----------------------------------------------------------------------===//
186 
get(spirv::Version version,ArrayRef<spirv::Capability> capabilities,ArrayRef<spirv::Extension> extensions,MLIRContext * context)187 spirv::VerCapExtAttr spirv::VerCapExtAttr::get(
188     spirv::Version version, ArrayRef<spirv::Capability> capabilities,
189     ArrayRef<spirv::Extension> extensions, MLIRContext *context) {
190   Builder b(context);
191 
192   auto versionAttr = b.getI32IntegerAttr(static_cast<uint32_t>(version));
193 
194   SmallVector<Attribute, 4> capAttrs;
195   capAttrs.reserve(capabilities.size());
196   for (spirv::Capability cap : capabilities)
197     capAttrs.push_back(b.getI32IntegerAttr(static_cast<uint32_t>(cap)));
198 
199   SmallVector<Attribute, 4> extAttrs;
200   extAttrs.reserve(extensions.size());
201   for (spirv::Extension ext : extensions)
202     extAttrs.push_back(b.getStringAttr(spirv::stringifyExtension(ext)));
203 
204   return get(versionAttr, b.getArrayAttr(capAttrs), b.getArrayAttr(extAttrs));
205 }
206 
get(IntegerAttr version,ArrayAttr capabilities,ArrayAttr extensions)207 spirv::VerCapExtAttr spirv::VerCapExtAttr::get(IntegerAttr version,
208                                                ArrayAttr capabilities,
209                                                ArrayAttr extensions) {
210   assert(version && capabilities && extensions);
211   MLIRContext *context = version.getContext();
212   return Base::get(context, version, capabilities, extensions);
213 }
214 
getKindName()215 StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
216 
getVersion()217 spirv::Version spirv::VerCapExtAttr::getVersion() {
218   return static_cast<spirv::Version>(
219       getImpl()->version.cast<IntegerAttr>().getValue().getZExtValue());
220 }
221 
ext_iterator(ArrayAttr::iterator it)222 spirv::VerCapExtAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it)
223     : llvm::mapped_iterator<ArrayAttr::iterator,
224                             spirv::Extension (*)(Attribute)>(
225           it, [](Attribute attr) {
226             return *symbolizeExtension(attr.cast<StringAttr>().getValue());
227           }) {}
228 
getExtensions()229 spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() {
230   auto range = getExtensionsAttr().getValue();
231   return {ext_iterator(range.begin()), ext_iterator(range.end())};
232 }
233 
getExtensionsAttr()234 ArrayAttr spirv::VerCapExtAttr::getExtensionsAttr() {
235   return getImpl()->extensions.cast<ArrayAttr>();
236 }
237 
cap_iterator(ArrayAttr::iterator it)238 spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it)
239     : llvm::mapped_iterator<ArrayAttr::iterator,
240                             spirv::Capability (*)(Attribute)>(
241           it, [](Attribute attr) {
242             return *symbolizeCapability(
243                 attr.cast<IntegerAttr>().getValue().getZExtValue());
244           }) {}
245 
getCapabilities()246 spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() {
247   auto range = getCapabilitiesAttr().getValue();
248   return {cap_iterator(range.begin()), cap_iterator(range.end())};
249 }
250 
getCapabilitiesAttr()251 ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
252   return getImpl()->capabilities.cast<ArrayAttr>();
253 }
254 
255 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,IntegerAttr version,ArrayAttr capabilities,ArrayAttr extensions)256 spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
257                              IntegerAttr version, ArrayAttr capabilities,
258                              ArrayAttr extensions) {
259   if (!version.getType().isSignlessInteger(32))
260     return emitError() << "expected 32-bit integer for version";
261 
262   if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {
263         if (auto intAttr = attr.dyn_cast<IntegerAttr>())
264           if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue()))
265             return true;
266         return false;
267       }))
268     return emitError() << "unknown capability in capability list";
269 
270   if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
271         if (auto strAttr = attr.dyn_cast<StringAttr>())
272           if (spirv::symbolizeExtension(strAttr.getValue()))
273             return true;
274         return false;
275       }))
276     return emitError() << "unknown extension in extension list";
277 
278   return success();
279 }
280 
281 //===----------------------------------------------------------------------===//
282 // TargetEnvAttr
283 //===----------------------------------------------------------------------===//
284 
get(spirv::VerCapExtAttr triple,Vendor vendorID,DeviceType deviceType,uint32_t deviceID,ResourceLimitsAttr limits)285 spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
286                                                Vendor vendorID,
287                                                DeviceType deviceType,
288                                                uint32_t deviceID,
289                                                ResourceLimitsAttr limits) {
290   assert(triple && limits && "expected valid triple and limits");
291   MLIRContext *context = triple.getContext();
292   return Base::get(context, triple, vendorID, deviceType, deviceID, limits);
293 }
294 
getKindName()295 StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
296 
getTripleAttr() const297 spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const {
298   return getImpl()->triple.cast<spirv::VerCapExtAttr>();
299 }
300 
getVersion() const301 spirv::Version spirv::TargetEnvAttr::getVersion() const {
302   return getTripleAttr().getVersion();
303 }
304 
getExtensions()305 spirv::VerCapExtAttr::ext_range spirv::TargetEnvAttr::getExtensions() {
306   return getTripleAttr().getExtensions();
307 }
308 
getExtensionsAttr()309 ArrayAttr spirv::TargetEnvAttr::getExtensionsAttr() {
310   return getTripleAttr().getExtensionsAttr();
311 }
312 
getCapabilities()313 spirv::VerCapExtAttr::cap_range spirv::TargetEnvAttr::getCapabilities() {
314   return getTripleAttr().getCapabilities();
315 }
316 
getCapabilitiesAttr()317 ArrayAttr spirv::TargetEnvAttr::getCapabilitiesAttr() {
318   return getTripleAttr().getCapabilitiesAttr();
319 }
320 
getVendorID() const321 spirv::Vendor spirv::TargetEnvAttr::getVendorID() const {
322   return getImpl()->vendorID;
323 }
324 
getDeviceType() const325 spirv::DeviceType spirv::TargetEnvAttr::getDeviceType() const {
326   return getImpl()->deviceType;
327 }
328 
getDeviceID() const329 uint32_t spirv::TargetEnvAttr::getDeviceID() const {
330   return getImpl()->deviceID;
331 }
332 
getResourceLimits() const333 spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const {
334   return getImpl()->limits.cast<spirv::ResourceLimitsAttr>();
335 }
336 
337 //===----------------------------------------------------------------------===//
338 // ODS Generated Attributes
339 //===----------------------------------------------------------------------===//
340 
341 #define GET_ATTRDEF_CLASSES
342 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.cpp.inc"
343 
344 //===----------------------------------------------------------------------===//
345 // Attribute Parsing
346 //===----------------------------------------------------------------------===//
347 
348 /// Parses a comma-separated list of keywords, invokes `processKeyword` on each
349 /// of the parsed keyword, and returns failure if any error occurs.
350 static ParseResult
parseKeywordList(DialectAsmParser & parser,function_ref<LogicalResult (SMLoc,StringRef)> processKeyword)351 parseKeywordList(DialectAsmParser &parser,
352                  function_ref<LogicalResult(SMLoc, StringRef)> processKeyword) {
353   if (parser.parseLSquare())
354     return failure();
355 
356   // Special case for empty list.
357   if (succeeded(parser.parseOptionalRSquare()))
358     return success();
359 
360   // Keep parsing the keyword and an optional comma following it. If the comma
361   // is successfully parsed, then we have more keywords to parse.
362   if (failed(parser.parseCommaSeparatedList([&]() {
363         auto loc = parser.getCurrentLocation();
364         StringRef keyword;
365         if (parser.parseKeyword(&keyword) ||
366             failed(processKeyword(loc, keyword)))
367           return failure();
368         return success();
369       })))
370     return failure();
371   return parser.parseRSquare();
372 }
373 
374 /// Parses a spirv::InterfaceVarABIAttr.
parseInterfaceVarABIAttr(DialectAsmParser & parser)375 static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser) {
376   if (parser.parseLess())
377     return {};
378 
379   Builder &builder = parser.getBuilder();
380 
381   if (parser.parseLParen())
382     return {};
383 
384   IntegerAttr descriptorSetAttr;
385   {
386     auto loc = parser.getCurrentLocation();
387     uint32_t descriptorSet = 0;
388     auto descriptorSetParseResult = parser.parseOptionalInteger(descriptorSet);
389 
390     if (!descriptorSetParseResult.hasValue() ||
391         failed(*descriptorSetParseResult)) {
392       parser.emitError(loc, "missing descriptor set");
393       return {};
394     }
395     descriptorSetAttr = builder.getI32IntegerAttr(descriptorSet);
396   }
397 
398   if (parser.parseComma())
399     return {};
400 
401   IntegerAttr bindingAttr;
402   {
403     auto loc = parser.getCurrentLocation();
404     uint32_t binding = 0;
405     auto bindingParseResult = parser.parseOptionalInteger(binding);
406 
407     if (!bindingParseResult.hasValue() || failed(*bindingParseResult)) {
408       parser.emitError(loc, "missing binding");
409       return {};
410     }
411     bindingAttr = builder.getI32IntegerAttr(binding);
412   }
413 
414   if (parser.parseRParen())
415     return {};
416 
417   IntegerAttr storageClassAttr;
418   {
419     if (succeeded(parser.parseOptionalComma())) {
420       auto loc = parser.getCurrentLocation();
421       StringRef storageClass;
422       if (parser.parseKeyword(&storageClass))
423         return {};
424 
425       if (auto storageClassSymbol =
426               spirv::symbolizeStorageClass(storageClass)) {
427         storageClassAttr = builder.getI32IntegerAttr(
428             static_cast<uint32_t>(*storageClassSymbol));
429       } else {
430         parser.emitError(loc, "unknown storage class: ") << storageClass;
431         return {};
432       }
433     }
434   }
435 
436   if (parser.parseGreater())
437     return {};
438 
439   return spirv::InterfaceVarABIAttr::get(descriptorSetAttr, bindingAttr,
440                                          storageClassAttr);
441 }
442 
/// Parses a spirv::VerCapExtAttr body, i.e. everything after the `vce` kind
/// keyword:
///   `<` version `,` `[` capability-list `]` `,` `[` extension-list `]` `>`
/// Either list may be empty (`[]`). The version and capabilities are stored as
/// i32 enum values; extensions are stored as StringAttr names. Returns a null
/// Attribute on failure.
static Attribute parseVerCapExtAttr(DialectAsmParser &parser) {
  if (parser.parseLess())
    return {};

  Builder &builder = parser.getBuilder();

  // Version keyword followed by a comma.
  IntegerAttr versionAttr;
  {
    auto loc = parser.getCurrentLocation();
    StringRef version;
    if (parser.parseKeyword(&version) || parser.parseComma())
      return {};

    if (auto versionSymbol = spirv::symbolizeVersion(version)) {
      versionAttr =
          builder.getI32IntegerAttr(static_cast<uint32_t>(*versionSymbol));
    } else {
      parser.emitError(loc, "unknown version: ") << version;
      return {};
    }
  }

  // Capability list followed by a comma.
  ArrayAttr capabilitiesAttr;
  {
    SmallVector<Attribute, 4> capabilities;
    // Filled in by processCapability when a keyword is not a known capability,
    // so the diagnostic can be emitted after parseKeywordList fails. Stays
    // empty when the failure came from the parser itself (e.g. bad syntax),
    // which is presumably already diagnosed there.
    SMLoc errorloc;
    StringRef errorKeyword;

    auto processCapability = [&](SMLoc loc, StringRef capability) {
      if (auto capSymbol = spirv::symbolizeCapability(capability)) {
        capabilities.push_back(
            builder.getI32IntegerAttr(static_cast<uint32_t>(*capSymbol)));
        return success();
      }
      // Comma operator: record the offending keyword, then report failure.
      return errorloc = loc, errorKeyword = capability, failure();
    };
    if (parseKeywordList(parser, processCapability) || parser.parseComma()) {
      if (!errorKeyword.empty())
        parser.emitError(errorloc, "unknown capability: ") << errorKeyword;
      return {};
    }

    capabilitiesAttr = builder.getArrayAttr(capabilities);
  }

  // Extension list; uses the same error-capture scheme as capabilities.
  ArrayAttr extensionsAttr;
  {
    SmallVector<Attribute, 1> extensions;
    SMLoc errorloc;
    StringRef errorKeyword;

    auto processExtension = [&](SMLoc loc, StringRef extension) {
      if (spirv::symbolizeExtension(extension)) {
        extensions.push_back(builder.getStringAttr(extension));
        return success();
      }
      return errorloc = loc, errorKeyword = extension, failure();
    };
    if (parseKeywordList(parser, processExtension)) {
      if (!errorKeyword.empty())
        parser.emitError(errorloc, "unknown extension: ") << errorKeyword;
      return {};
    }

    extensionsAttr = builder.getArrayAttr(extensions);
  }

  if (parser.parseGreater())
    return {};

  return spirv::VerCapExtAttr::get(versionAttr, capabilitiesAttr,
                                   extensionsAttr);
}
516 
517 /// Parses a spirv::TargetEnvAttr.
parseTargetEnvAttr(DialectAsmParser & parser)518 static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
519   if (parser.parseLess())
520     return {};
521 
522   spirv::VerCapExtAttr tripleAttr;
523   if (parser.parseAttribute(tripleAttr) || parser.parseComma())
524     return {};
525 
526   // Parse [vendor[:device-type[:device-id]]]
527   Vendor vendorID = Vendor::Unknown;
528   DeviceType deviceType = DeviceType::Unknown;
529   uint32_t deviceID = spirv::TargetEnvAttr::kUnknownDeviceID;
530   {
531     auto loc = parser.getCurrentLocation();
532     StringRef vendorStr;
533     if (succeeded(parser.parseOptionalKeyword(&vendorStr))) {
534       if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) {
535         vendorID = *vendorSymbol;
536       } else {
537         parser.emitError(loc, "unknown vendor: ") << vendorStr;
538       }
539 
540       if (succeeded(parser.parseOptionalColon())) {
541         loc = parser.getCurrentLocation();
542         StringRef deviceTypeStr;
543         if (parser.parseKeyword(&deviceTypeStr))
544           return {};
545         if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) {
546           deviceType = *deviceTypeSymbol;
547         } else {
548           parser.emitError(loc, "unknown device type: ") << deviceTypeStr;
549         }
550 
551         if (succeeded(parser.parseOptionalColon())) {
552           loc = parser.getCurrentLocation();
553           if (parser.parseInteger(deviceID))
554             return {};
555         }
556       }
557       if (parser.parseComma())
558         return {};
559     }
560   }
561 
562   ResourceLimitsAttr limitsAttr;
563   if (parser.parseAttribute(limitsAttr) || parser.parseGreater())
564     return {};
565 
566   return spirv::TargetEnvAttr::get(tripleAttr, vendorID, deviceType, deviceID,
567                                    limitsAttr);
568 }
569 
parseAttribute(DialectAsmParser & parser,Type type) const570 Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
571                                        Type type) const {
572   // SPIR-V attributes are dictionaries so they do not have type.
573   if (type) {
574     parser.emitError(parser.getNameLoc(), "unexpected type");
575     return {};
576   }
577 
578   // Parse the kind keyword first.
579   StringRef attrKind;
580   Attribute attr;
581   OptionalParseResult result =
582       generatedAttributeParser(parser, &attrKind, type, attr);
583   if (result.hasValue())
584     return attr;
585 
586   if (attrKind == spirv::TargetEnvAttr::getKindName())
587     return parseTargetEnvAttr(parser);
588   if (attrKind == spirv::VerCapExtAttr::getKindName())
589     return parseVerCapExtAttr(parser);
590   if (attrKind == spirv::InterfaceVarABIAttr::getKindName())
591     return parseInterfaceVarABIAttr(parser);
592 
593   parser.emitError(parser.getNameLoc(), "unknown SPIR-V attribute kind: ")
594       << attrKind;
595   return {};
596 }
597 
598 //===----------------------------------------------------------------------===//
599 // Attribute Printing
600 //===----------------------------------------------------------------------===//
601 
print(spirv::VerCapExtAttr triple,DialectAsmPrinter & printer)602 static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) {
603   auto &os = printer.getStream();
604   printer << spirv::VerCapExtAttr::getKindName() << "<"
605           << spirv::stringifyVersion(triple.getVersion()) << ", [";
606   llvm::interleaveComma(
607       triple.getCapabilities(), os,
608       [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); });
609   printer << "], [";
610   llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) {
611     os << attr.cast<StringAttr>().getValue();
612   });
613   printer << "]>";
614 }
615 
print(spirv::TargetEnvAttr targetEnv,DialectAsmPrinter & printer)616 static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
617   printer << spirv::TargetEnvAttr::getKindName() << "<#spv.";
618   print(targetEnv.getTripleAttr(), printer);
619   spirv::Vendor vendorID = targetEnv.getVendorID();
620   spirv::DeviceType deviceType = targetEnv.getDeviceType();
621   uint32_t deviceID = targetEnv.getDeviceID();
622   if (vendorID != spirv::Vendor::Unknown) {
623     printer << ", " << spirv::stringifyVendor(vendorID);
624     if (deviceType != spirv::DeviceType::Unknown) {
625       printer << ":" << spirv::stringifyDeviceType(deviceType);
626       if (deviceID != spirv::TargetEnvAttr::kUnknownDeviceID)
627         printer << ":" << deviceID;
628     }
629   }
630   printer << ", " << targetEnv.getResourceLimits() << ">";
631 }
632 
print(spirv::InterfaceVarABIAttr interfaceVarABIAttr,DialectAsmPrinter & printer)633 static void print(spirv::InterfaceVarABIAttr interfaceVarABIAttr,
634                   DialectAsmPrinter &printer) {
635   printer << spirv::InterfaceVarABIAttr::getKindName() << "<("
636           << interfaceVarABIAttr.getDescriptorSet() << ", "
637           << interfaceVarABIAttr.getBinding() << ")";
638   auto storageClass = interfaceVarABIAttr.getStorageClass();
639   if (storageClass)
640     printer << ", " << spirv::stringifyStorageClass(*storageClass);
641   printer << ">";
642 }
643 
printAttribute(Attribute attr,DialectAsmPrinter & printer) const644 void SPIRVDialect::printAttribute(Attribute attr,
645                                   DialectAsmPrinter &printer) const {
646   if (succeeded(generatedAttributePrinter(attr, printer)))
647     return;
648 
649   if (auto targetEnv = attr.dyn_cast<TargetEnvAttr>())
650     print(targetEnv, printer);
651   else if (auto vceAttr = attr.dyn_cast<VerCapExtAttr>())
652     print(vceAttr, printer);
653   else if (auto interfaceVarABIAttr = attr.dyn_cast<InterfaceVarABIAttr>())
654     print(interfaceVarABIAttr, printer);
655   else
656     llvm_unreachable("unhandled SPIR-V attribute kind");
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // SPIR-V Dialect
661 //===----------------------------------------------------------------------===//
662 
/// Registers all SPIR-V attribute kinds with the dialect: the hand-written
/// attributes defined in this file, plus the ODS-generated ones expanded from
/// SPIRVAttributes.cpp.inc via GET_ATTRDEF_LIST.
void spirv::SPIRVDialect::registerAttributes() {
  // Attributes whose storage, parsing and printing are implemented in this file.
  addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
  // Attributes generated from ODS definitions; the include expands to their
  // comma-separated class list.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.cpp.inc"
      >();
}
670