1 //===- SPIRVAttributes.cpp - SPIR-V attribute definitions -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
10 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
11 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "llvm/ADT/TypeSwitch.h"
15
16 using namespace mlir;
17 using namespace mlir::spirv;
18
19 //===----------------------------------------------------------------------===//
20 // TableGen'erated attribute utility functions
21 //===----------------------------------------------------------------------===//
22
23 namespace mlir {
24 namespace spirv {
25 #include "mlir/Dialect/SPIRV/IR/SPIRVAttrUtils.inc"
26 } // namespace spirv
27
28 //===----------------------------------------------------------------------===//
29 // Attribute storage classes
30 //===----------------------------------------------------------------------===//
31
32 namespace spirv {
33 namespace detail {
34
35 struct InterfaceVarABIAttributeStorage : public AttributeStorage {
36 using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
37
InterfaceVarABIAttributeStoragemlir::spirv::detail::InterfaceVarABIAttributeStorage38 InterfaceVarABIAttributeStorage(Attribute descriptorSet, Attribute binding,
39 Attribute storageClass)
40 : descriptorSet(descriptorSet), binding(binding),
41 storageClass(storageClass) {}
42
operator ==mlir::spirv::detail::InterfaceVarABIAttributeStorage43 bool operator==(const KeyTy &key) const {
44 return std::get<0>(key) == descriptorSet && std::get<1>(key) == binding &&
45 std::get<2>(key) == storageClass;
46 }
47
48 static InterfaceVarABIAttributeStorage *
constructmlir::spirv::detail::InterfaceVarABIAttributeStorage49 construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
50 return new (allocator.allocate<InterfaceVarABIAttributeStorage>())
51 InterfaceVarABIAttributeStorage(std::get<0>(key), std::get<1>(key),
52 std::get<2>(key));
53 }
54
55 Attribute descriptorSet;
56 Attribute binding;
57 Attribute storageClass;
58 };
59
60 struct VerCapExtAttributeStorage : public AttributeStorage {
61 using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
62
VerCapExtAttributeStoragemlir::spirv::detail::VerCapExtAttributeStorage63 VerCapExtAttributeStorage(Attribute version, Attribute capabilities,
64 Attribute extensions)
65 : version(version), capabilities(capabilities), extensions(extensions) {}
66
operator ==mlir::spirv::detail::VerCapExtAttributeStorage67 bool operator==(const KeyTy &key) const {
68 return std::get<0>(key) == version && std::get<1>(key) == capabilities &&
69 std::get<2>(key) == extensions;
70 }
71
72 static VerCapExtAttributeStorage *
constructmlir::spirv::detail::VerCapExtAttributeStorage73 construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
74 return new (allocator.allocate<VerCapExtAttributeStorage>())
75 VerCapExtAttributeStorage(std::get<0>(key), std::get<1>(key),
76 std::get<2>(key));
77 }
78
79 Attribute version;
80 Attribute capabilities;
81 Attribute extensions;
82 };
83
84 struct TargetEnvAttributeStorage : public AttributeStorage {
85 using KeyTy = std::tuple<Attribute, Vendor, DeviceType, uint32_t, Attribute>;
86
TargetEnvAttributeStoragemlir::spirv::detail::TargetEnvAttributeStorage87 TargetEnvAttributeStorage(Attribute triple, Vendor vendorID,
88 DeviceType deviceType, uint32_t deviceID,
89 Attribute limits)
90 : triple(triple), limits(limits), vendorID(vendorID),
91 deviceType(deviceType), deviceID(deviceID) {}
92
operator ==mlir::spirv::detail::TargetEnvAttributeStorage93 bool operator==(const KeyTy &key) const {
94 return key ==
95 std::make_tuple(triple, vendorID, deviceType, deviceID, limits);
96 }
97
98 static TargetEnvAttributeStorage *
constructmlir::spirv::detail::TargetEnvAttributeStorage99 construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
100 return new (allocator.allocate<TargetEnvAttributeStorage>())
101 TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key),
102 std::get<2>(key), std::get<3>(key),
103 std::get<4>(key));
104 }
105
106 Attribute triple;
107 Attribute limits;
108 Vendor vendorID;
109 DeviceType deviceType;
110 uint32_t deviceID;
111 };
112 } // namespace detail
113 } // namespace spirv
114 } // namespace mlir
115
116 //===----------------------------------------------------------------------===//
117 // InterfaceVarABIAttr
118 //===----------------------------------------------------------------------===//
119
120 spirv::InterfaceVarABIAttr
get(uint32_t descriptorSet,uint32_t binding,Optional<spirv::StorageClass> storageClass,MLIRContext * context)121 spirv::InterfaceVarABIAttr::get(uint32_t descriptorSet, uint32_t binding,
122 Optional<spirv::StorageClass> storageClass,
123 MLIRContext *context) {
124 Builder b(context);
125 auto descriptorSetAttr = b.getI32IntegerAttr(descriptorSet);
126 auto bindingAttr = b.getI32IntegerAttr(binding);
127 auto storageClassAttr =
128 storageClass ? b.getI32IntegerAttr(static_cast<uint32_t>(*storageClass))
129 : IntegerAttr();
130 return get(descriptorSetAttr, bindingAttr, storageClassAttr);
131 }
132
133 spirv::InterfaceVarABIAttr
get(IntegerAttr descriptorSet,IntegerAttr binding,IntegerAttr storageClass)134 spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding,
135 IntegerAttr storageClass) {
136 assert(descriptorSet && binding);
137 MLIRContext *context = descriptorSet.getContext();
138 return Base::get(context, descriptorSet, binding, storageClass);
139 }
140
getKindName()141 StringRef spirv::InterfaceVarABIAttr::getKindName() {
142 return "interface_var_abi";
143 }
144
getBinding()145 uint32_t spirv::InterfaceVarABIAttr::getBinding() {
146 return getImpl()->binding.cast<IntegerAttr>().getInt();
147 }
148
getDescriptorSet()149 uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() {
150 return getImpl()->descriptorSet.cast<IntegerAttr>().getInt();
151 }
152
getStorageClass()153 Optional<spirv::StorageClass> spirv::InterfaceVarABIAttr::getStorageClass() {
154 if (getImpl()->storageClass)
155 return static_cast<spirv::StorageClass>(
156 getImpl()->storageClass.cast<IntegerAttr>().getValue().getZExtValue());
157 return llvm::None;
158 }
159
verify(function_ref<InFlightDiagnostic ()> emitError,IntegerAttr descriptorSet,IntegerAttr binding,IntegerAttr storageClass)160 LogicalResult spirv::InterfaceVarABIAttr::verify(
161 function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet,
162 IntegerAttr binding, IntegerAttr storageClass) {
163 if (!descriptorSet.getType().isSignlessInteger(32))
164 return emitError() << "expected 32-bit integer for descriptor set";
165
166 if (!binding.getType().isSignlessInteger(32))
167 return emitError() << "expected 32-bit integer for binding";
168
169 if (storageClass) {
170 if (auto storageClassAttr = storageClass.cast<IntegerAttr>()) {
171 auto storageClassValue =
172 spirv::symbolizeStorageClass(storageClassAttr.getInt());
173 if (!storageClassValue)
174 return emitError() << "unknown storage class";
175 } else {
176 return emitError() << "expected valid storage class";
177 }
178 }
179
180 return success();
181 }
182
183 //===----------------------------------------------------------------------===//
184 // VerCapExtAttr
185 //===----------------------------------------------------------------------===//
186
get(spirv::Version version,ArrayRef<spirv::Capability> capabilities,ArrayRef<spirv::Extension> extensions,MLIRContext * context)187 spirv::VerCapExtAttr spirv::VerCapExtAttr::get(
188 spirv::Version version, ArrayRef<spirv::Capability> capabilities,
189 ArrayRef<spirv::Extension> extensions, MLIRContext *context) {
190 Builder b(context);
191
192 auto versionAttr = b.getI32IntegerAttr(static_cast<uint32_t>(version));
193
194 SmallVector<Attribute, 4> capAttrs;
195 capAttrs.reserve(capabilities.size());
196 for (spirv::Capability cap : capabilities)
197 capAttrs.push_back(b.getI32IntegerAttr(static_cast<uint32_t>(cap)));
198
199 SmallVector<Attribute, 4> extAttrs;
200 extAttrs.reserve(extensions.size());
201 for (spirv::Extension ext : extensions)
202 extAttrs.push_back(b.getStringAttr(spirv::stringifyExtension(ext)));
203
204 return get(versionAttr, b.getArrayAttr(capAttrs), b.getArrayAttr(extAttrs));
205 }
206
get(IntegerAttr version,ArrayAttr capabilities,ArrayAttr extensions)207 spirv::VerCapExtAttr spirv::VerCapExtAttr::get(IntegerAttr version,
208 ArrayAttr capabilities,
209 ArrayAttr extensions) {
210 assert(version && capabilities && extensions);
211 MLIRContext *context = version.getContext();
212 return Base::get(context, version, capabilities, extensions);
213 }
214
getKindName()215 StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
216
getVersion()217 spirv::Version spirv::VerCapExtAttr::getVersion() {
218 return static_cast<spirv::Version>(
219 getImpl()->version.cast<IntegerAttr>().getValue().getZExtValue());
220 }
221
ext_iterator(ArrayAttr::iterator it)222 spirv::VerCapExtAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it)
223 : llvm::mapped_iterator<ArrayAttr::iterator,
224 spirv::Extension (*)(Attribute)>(
225 it, [](Attribute attr) {
226 return *symbolizeExtension(attr.cast<StringAttr>().getValue());
227 }) {}
228
getExtensions()229 spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() {
230 auto range = getExtensionsAttr().getValue();
231 return {ext_iterator(range.begin()), ext_iterator(range.end())};
232 }
233
getExtensionsAttr()234 ArrayAttr spirv::VerCapExtAttr::getExtensionsAttr() {
235 return getImpl()->extensions.cast<ArrayAttr>();
236 }
237
cap_iterator(ArrayAttr::iterator it)238 spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it)
239 : llvm::mapped_iterator<ArrayAttr::iterator,
240 spirv::Capability (*)(Attribute)>(
241 it, [](Attribute attr) {
242 return *symbolizeCapability(
243 attr.cast<IntegerAttr>().getValue().getZExtValue());
244 }) {}
245
getCapabilities()246 spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() {
247 auto range = getCapabilitiesAttr().getValue();
248 return {cap_iterator(range.begin()), cap_iterator(range.end())};
249 }
250
getCapabilitiesAttr()251 ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
252 return getImpl()->capabilities.cast<ArrayAttr>();
253 }
254
255 LogicalResult
verify(function_ref<InFlightDiagnostic ()> emitError,IntegerAttr version,ArrayAttr capabilities,ArrayAttr extensions)256 spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
257 IntegerAttr version, ArrayAttr capabilities,
258 ArrayAttr extensions) {
259 if (!version.getType().isSignlessInteger(32))
260 return emitError() << "expected 32-bit integer for version";
261
262 if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {
263 if (auto intAttr = attr.dyn_cast<IntegerAttr>())
264 if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue()))
265 return true;
266 return false;
267 }))
268 return emitError() << "unknown capability in capability list";
269
270 if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
271 if (auto strAttr = attr.dyn_cast<StringAttr>())
272 if (spirv::symbolizeExtension(strAttr.getValue()))
273 return true;
274 return false;
275 }))
276 return emitError() << "unknown extension in extension list";
277
278 return success();
279 }
280
281 //===----------------------------------------------------------------------===//
282 // TargetEnvAttr
283 //===----------------------------------------------------------------------===//
284
get(spirv::VerCapExtAttr triple,Vendor vendorID,DeviceType deviceType,uint32_t deviceID,ResourceLimitsAttr limits)285 spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
286 Vendor vendorID,
287 DeviceType deviceType,
288 uint32_t deviceID,
289 ResourceLimitsAttr limits) {
290 assert(triple && limits && "expected valid triple and limits");
291 MLIRContext *context = triple.getContext();
292 return Base::get(context, triple, vendorID, deviceType, deviceID, limits);
293 }
294
getKindName()295 StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
296
getTripleAttr() const297 spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const {
298 return getImpl()->triple.cast<spirv::VerCapExtAttr>();
299 }
300
getVersion() const301 spirv::Version spirv::TargetEnvAttr::getVersion() const {
302 return getTripleAttr().getVersion();
303 }
304
getExtensions()305 spirv::VerCapExtAttr::ext_range spirv::TargetEnvAttr::getExtensions() {
306 return getTripleAttr().getExtensions();
307 }
308
getExtensionsAttr()309 ArrayAttr spirv::TargetEnvAttr::getExtensionsAttr() {
310 return getTripleAttr().getExtensionsAttr();
311 }
312
getCapabilities()313 spirv::VerCapExtAttr::cap_range spirv::TargetEnvAttr::getCapabilities() {
314 return getTripleAttr().getCapabilities();
315 }
316
getCapabilitiesAttr()317 ArrayAttr spirv::TargetEnvAttr::getCapabilitiesAttr() {
318 return getTripleAttr().getCapabilitiesAttr();
319 }
320
getVendorID() const321 spirv::Vendor spirv::TargetEnvAttr::getVendorID() const {
322 return getImpl()->vendorID;
323 }
324
getDeviceType() const325 spirv::DeviceType spirv::TargetEnvAttr::getDeviceType() const {
326 return getImpl()->deviceType;
327 }
328
getDeviceID() const329 uint32_t spirv::TargetEnvAttr::getDeviceID() const {
330 return getImpl()->deviceID;
331 }
332
getResourceLimits() const333 spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const {
334 return getImpl()->limits.cast<spirv::ResourceLimitsAttr>();
335 }
336
337 //===----------------------------------------------------------------------===//
338 // ODS Generated Attributes
339 //===----------------------------------------------------------------------===//
340
341 #define GET_ATTRDEF_CLASSES
342 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.cpp.inc"
343
344 //===----------------------------------------------------------------------===//
345 // Attribute Parsing
346 //===----------------------------------------------------------------------===//
347
348 /// Parses a comma-separated list of keywords, invokes `processKeyword` on each
349 /// of the parsed keyword, and returns failure if any error occurs.
350 static ParseResult
parseKeywordList(DialectAsmParser & parser,function_ref<LogicalResult (SMLoc,StringRef)> processKeyword)351 parseKeywordList(DialectAsmParser &parser,
352 function_ref<LogicalResult(SMLoc, StringRef)> processKeyword) {
353 if (parser.parseLSquare())
354 return failure();
355
356 // Special case for empty list.
357 if (succeeded(parser.parseOptionalRSquare()))
358 return success();
359
360 // Keep parsing the keyword and an optional comma following it. If the comma
361 // is successfully parsed, then we have more keywords to parse.
362 if (failed(parser.parseCommaSeparatedList([&]() {
363 auto loc = parser.getCurrentLocation();
364 StringRef keyword;
365 if (parser.parseKeyword(&keyword) ||
366 failed(processKeyword(loc, keyword)))
367 return failure();
368 return success();
369 })))
370 return failure();
371 return parser.parseRSquare();
372 }
373
374 /// Parses a spirv::InterfaceVarABIAttr.
parseInterfaceVarABIAttr(DialectAsmParser & parser)375 static Attribute parseInterfaceVarABIAttr(DialectAsmParser &parser) {
376 if (parser.parseLess())
377 return {};
378
379 Builder &builder = parser.getBuilder();
380
381 if (parser.parseLParen())
382 return {};
383
384 IntegerAttr descriptorSetAttr;
385 {
386 auto loc = parser.getCurrentLocation();
387 uint32_t descriptorSet = 0;
388 auto descriptorSetParseResult = parser.parseOptionalInteger(descriptorSet);
389
390 if (!descriptorSetParseResult.hasValue() ||
391 failed(*descriptorSetParseResult)) {
392 parser.emitError(loc, "missing descriptor set");
393 return {};
394 }
395 descriptorSetAttr = builder.getI32IntegerAttr(descriptorSet);
396 }
397
398 if (parser.parseComma())
399 return {};
400
401 IntegerAttr bindingAttr;
402 {
403 auto loc = parser.getCurrentLocation();
404 uint32_t binding = 0;
405 auto bindingParseResult = parser.parseOptionalInteger(binding);
406
407 if (!bindingParseResult.hasValue() || failed(*bindingParseResult)) {
408 parser.emitError(loc, "missing binding");
409 return {};
410 }
411 bindingAttr = builder.getI32IntegerAttr(binding);
412 }
413
414 if (parser.parseRParen())
415 return {};
416
417 IntegerAttr storageClassAttr;
418 {
419 if (succeeded(parser.parseOptionalComma())) {
420 auto loc = parser.getCurrentLocation();
421 StringRef storageClass;
422 if (parser.parseKeyword(&storageClass))
423 return {};
424
425 if (auto storageClassSymbol =
426 spirv::symbolizeStorageClass(storageClass)) {
427 storageClassAttr = builder.getI32IntegerAttr(
428 static_cast<uint32_t>(*storageClassSymbol));
429 } else {
430 parser.emitError(loc, "unknown storage class: ") << storageClass;
431 return {};
432 }
433 }
434 }
435
436 if (parser.parseGreater())
437 return {};
438
439 return spirv::InterfaceVarABIAttr::get(descriptorSetAttr, bindingAttr,
440 storageClassAttr);
441 }
442
parseVerCapExtAttr(DialectAsmParser & parser)443 static Attribute parseVerCapExtAttr(DialectAsmParser &parser) {
444 if (parser.parseLess())
445 return {};
446
447 Builder &builder = parser.getBuilder();
448
449 IntegerAttr versionAttr;
450 {
451 auto loc = parser.getCurrentLocation();
452 StringRef version;
453 if (parser.parseKeyword(&version) || parser.parseComma())
454 return {};
455
456 if (auto versionSymbol = spirv::symbolizeVersion(version)) {
457 versionAttr =
458 builder.getI32IntegerAttr(static_cast<uint32_t>(*versionSymbol));
459 } else {
460 parser.emitError(loc, "unknown version: ") << version;
461 return {};
462 }
463 }
464
465 ArrayAttr capabilitiesAttr;
466 {
467 SmallVector<Attribute, 4> capabilities;
468 SMLoc errorloc;
469 StringRef errorKeyword;
470
471 auto processCapability = [&](SMLoc loc, StringRef capability) {
472 if (auto capSymbol = spirv::symbolizeCapability(capability)) {
473 capabilities.push_back(
474 builder.getI32IntegerAttr(static_cast<uint32_t>(*capSymbol)));
475 return success();
476 }
477 return errorloc = loc, errorKeyword = capability, failure();
478 };
479 if (parseKeywordList(parser, processCapability) || parser.parseComma()) {
480 if (!errorKeyword.empty())
481 parser.emitError(errorloc, "unknown capability: ") << errorKeyword;
482 return {};
483 }
484
485 capabilitiesAttr = builder.getArrayAttr(capabilities);
486 }
487
488 ArrayAttr extensionsAttr;
489 {
490 SmallVector<Attribute, 1> extensions;
491 SMLoc errorloc;
492 StringRef errorKeyword;
493
494 auto processExtension = [&](SMLoc loc, StringRef extension) {
495 if (spirv::symbolizeExtension(extension)) {
496 extensions.push_back(builder.getStringAttr(extension));
497 return success();
498 }
499 return errorloc = loc, errorKeyword = extension, failure();
500 };
501 if (parseKeywordList(parser, processExtension)) {
502 if (!errorKeyword.empty())
503 parser.emitError(errorloc, "unknown extension: ") << errorKeyword;
504 return {};
505 }
506
507 extensionsAttr = builder.getArrayAttr(extensions);
508 }
509
510 if (parser.parseGreater())
511 return {};
512
513 return spirv::VerCapExtAttr::get(versionAttr, capabilitiesAttr,
514 extensionsAttr);
515 }
516
517 /// Parses a spirv::TargetEnvAttr.
parseTargetEnvAttr(DialectAsmParser & parser)518 static Attribute parseTargetEnvAttr(DialectAsmParser &parser) {
519 if (parser.parseLess())
520 return {};
521
522 spirv::VerCapExtAttr tripleAttr;
523 if (parser.parseAttribute(tripleAttr) || parser.parseComma())
524 return {};
525
526 // Parse [vendor[:device-type[:device-id]]]
527 Vendor vendorID = Vendor::Unknown;
528 DeviceType deviceType = DeviceType::Unknown;
529 uint32_t deviceID = spirv::TargetEnvAttr::kUnknownDeviceID;
530 {
531 auto loc = parser.getCurrentLocation();
532 StringRef vendorStr;
533 if (succeeded(parser.parseOptionalKeyword(&vendorStr))) {
534 if (auto vendorSymbol = spirv::symbolizeVendor(vendorStr)) {
535 vendorID = *vendorSymbol;
536 } else {
537 parser.emitError(loc, "unknown vendor: ") << vendorStr;
538 }
539
540 if (succeeded(parser.parseOptionalColon())) {
541 loc = parser.getCurrentLocation();
542 StringRef deviceTypeStr;
543 if (parser.parseKeyword(&deviceTypeStr))
544 return {};
545 if (auto deviceTypeSymbol = spirv::symbolizeDeviceType(deviceTypeStr)) {
546 deviceType = *deviceTypeSymbol;
547 } else {
548 parser.emitError(loc, "unknown device type: ") << deviceTypeStr;
549 }
550
551 if (succeeded(parser.parseOptionalColon())) {
552 loc = parser.getCurrentLocation();
553 if (parser.parseInteger(deviceID))
554 return {};
555 }
556 }
557 if (parser.parseComma())
558 return {};
559 }
560 }
561
562 ResourceLimitsAttr limitsAttr;
563 if (parser.parseAttribute(limitsAttr) || parser.parseGreater())
564 return {};
565
566 return spirv::TargetEnvAttr::get(tripleAttr, vendorID, deviceType, deviceID,
567 limitsAttr);
568 }
569
parseAttribute(DialectAsmParser & parser,Type type) const570 Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
571 Type type) const {
572 // SPIR-V attributes are dictionaries so they do not have type.
573 if (type) {
574 parser.emitError(parser.getNameLoc(), "unexpected type");
575 return {};
576 }
577
578 // Parse the kind keyword first.
579 StringRef attrKind;
580 Attribute attr;
581 OptionalParseResult result =
582 generatedAttributeParser(parser, &attrKind, type, attr);
583 if (result.hasValue())
584 return attr;
585
586 if (attrKind == spirv::TargetEnvAttr::getKindName())
587 return parseTargetEnvAttr(parser);
588 if (attrKind == spirv::VerCapExtAttr::getKindName())
589 return parseVerCapExtAttr(parser);
590 if (attrKind == spirv::InterfaceVarABIAttr::getKindName())
591 return parseInterfaceVarABIAttr(parser);
592
593 parser.emitError(parser.getNameLoc(), "unknown SPIR-V attribute kind: ")
594 << attrKind;
595 return {};
596 }
597
598 //===----------------------------------------------------------------------===//
599 // Attribute Printing
600 //===----------------------------------------------------------------------===//
601
print(spirv::VerCapExtAttr triple,DialectAsmPrinter & printer)602 static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer) {
603 auto &os = printer.getStream();
604 printer << spirv::VerCapExtAttr::getKindName() << "<"
605 << spirv::stringifyVersion(triple.getVersion()) << ", [";
606 llvm::interleaveComma(
607 triple.getCapabilities(), os,
608 [&](spirv::Capability cap) { os << spirv::stringifyCapability(cap); });
609 printer << "], [";
610 llvm::interleaveComma(triple.getExtensionsAttr(), os, [&](Attribute attr) {
611 os << attr.cast<StringAttr>().getValue();
612 });
613 printer << "]>";
614 }
615
print(spirv::TargetEnvAttr targetEnv,DialectAsmPrinter & printer)616 static void print(spirv::TargetEnvAttr targetEnv, DialectAsmPrinter &printer) {
617 printer << spirv::TargetEnvAttr::getKindName() << "<#spv.";
618 print(targetEnv.getTripleAttr(), printer);
619 spirv::Vendor vendorID = targetEnv.getVendorID();
620 spirv::DeviceType deviceType = targetEnv.getDeviceType();
621 uint32_t deviceID = targetEnv.getDeviceID();
622 if (vendorID != spirv::Vendor::Unknown) {
623 printer << ", " << spirv::stringifyVendor(vendorID);
624 if (deviceType != spirv::DeviceType::Unknown) {
625 printer << ":" << spirv::stringifyDeviceType(deviceType);
626 if (deviceID != spirv::TargetEnvAttr::kUnknownDeviceID)
627 printer << ":" << deviceID;
628 }
629 }
630 printer << ", " << targetEnv.getResourceLimits() << ">";
631 }
632
print(spirv::InterfaceVarABIAttr interfaceVarABIAttr,DialectAsmPrinter & printer)633 static void print(spirv::InterfaceVarABIAttr interfaceVarABIAttr,
634 DialectAsmPrinter &printer) {
635 printer << spirv::InterfaceVarABIAttr::getKindName() << "<("
636 << interfaceVarABIAttr.getDescriptorSet() << ", "
637 << interfaceVarABIAttr.getBinding() << ")";
638 auto storageClass = interfaceVarABIAttr.getStorageClass();
639 if (storageClass)
640 printer << ", " << spirv::stringifyStorageClass(*storageClass);
641 printer << ">";
642 }
643
printAttribute(Attribute attr,DialectAsmPrinter & printer) const644 void SPIRVDialect::printAttribute(Attribute attr,
645 DialectAsmPrinter &printer) const {
646 if (succeeded(generatedAttributePrinter(attr, printer)))
647 return;
648
649 if (auto targetEnv = attr.dyn_cast<TargetEnvAttr>())
650 print(targetEnv, printer);
651 else if (auto vceAttr = attr.dyn_cast<VerCapExtAttr>())
652 print(vceAttr, printer);
653 else if (auto interfaceVarABIAttr = attr.dyn_cast<InterfaceVarABIAttr>())
654 print(interfaceVarABIAttr, printer);
655 else
656 llvm_unreachable("unhandled SPIR-V attribute kind");
657 }
658
659 //===----------------------------------------------------------------------===//
660 // SPIR-V Dialect
661 //===----------------------------------------------------------------------===//
662
registerAttributes()663 void spirv::SPIRVDialect::registerAttributes() {
664 addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
665 addAttributes<
666 #define GET_ATTRDEF_LIST
667 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.cpp.inc"
668 >();
669 }
670