1 //===- SPIRVAttributes.cpp - SPIR-V attribute definitions -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
10 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
11 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
12 #include "mlir/IR/Builders.h"
13 
14 using namespace mlir;
15 
16 //===----------------------------------------------------------------------===//
17 // TableGen'erated attribute utility functions
18 //===----------------------------------------------------------------------===//
19 
// Splice the TableGen-generated attribute utility functions into the
// mlir::spirv namespace so they are visible to the definitions below.
namespace mlir {
namespace spirv {
#include "mlir/Dialect/SPIRV/IR/SPIRVAttrUtils.inc"
} // namespace spirv
} // namespace mlir
25 
26 //===----------------------------------------------------------------------===//
// DictionaryAttr-derived attributes
28 //===----------------------------------------------------------------------===//
29 
// Generated definitions for the dictionary-based attributes of TargetAndABI
// (presumably including spirv::ResourceLimitsAttr, which TargetEnvAttr uses
// below -- confirm against the .td file).
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.cpp.inc"
31 
32 namespace mlir {
33 
34 //===----------------------------------------------------------------------===//
35 // Attribute storage classes
36 //===----------------------------------------------------------------------===//
37 
38 namespace spirv {
39 namespace detail {
40 
41 struct InterfaceVarABIAttributeStorage : public AttributeStorage {
42   using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
43 
44   InterfaceVarABIAttributeStorage(Attribute descriptorSet, Attribute binding,
45                                   Attribute storageClass)
46       : descriptorSet(descriptorSet), binding(binding),
47         storageClass(storageClass) {}
48 
49   bool operator==(const KeyTy &key) const {
50     return std::get<0>(key) == descriptorSet && std::get<1>(key) == binding &&
51            std::get<2>(key) == storageClass;
52   }
53 
54   static InterfaceVarABIAttributeStorage *
55   construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
56     return new (allocator.allocate<InterfaceVarABIAttributeStorage>())
57         InterfaceVarABIAttributeStorage(std::get<0>(key), std::get<1>(key),
58                                         std::get<2>(key));
59   }
60 
61   Attribute descriptorSet;
62   Attribute binding;
63   Attribute storageClass;
64 };
65 
66 struct VerCapExtAttributeStorage : public AttributeStorage {
67   using KeyTy = std::tuple<Attribute, Attribute, Attribute>;
68 
69   VerCapExtAttributeStorage(Attribute version, Attribute capabilities,
70                             Attribute extensions)
71       : version(version), capabilities(capabilities), extensions(extensions) {}
72 
73   bool operator==(const KeyTy &key) const {
74     return std::get<0>(key) == version && std::get<1>(key) == capabilities &&
75            std::get<2>(key) == extensions;
76   }
77 
78   static VerCapExtAttributeStorage *
79   construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
80     return new (allocator.allocate<VerCapExtAttributeStorage>())
81         VerCapExtAttributeStorage(std::get<0>(key), std::get<1>(key),
82                                   std::get<2>(key));
83   }
84 
85   Attribute version;
86   Attribute capabilities;
87   Attribute extensions;
88 };
89 
90 struct TargetEnvAttributeStorage : public AttributeStorage {
91   using KeyTy = std::tuple<Attribute, Vendor, DeviceType, uint32_t, Attribute>;
92 
93   TargetEnvAttributeStorage(Attribute triple, Vendor vendorID,
94                             DeviceType deviceType, uint32_t deviceID,
95                             Attribute limits)
96       : triple(triple), limits(limits), vendorID(vendorID),
97         deviceType(deviceType), deviceID(deviceID) {}
98 
99   bool operator==(const KeyTy &key) const {
100     return key ==
101            std::make_tuple(triple, vendorID, deviceType, deviceID, limits);
102   }
103 
104   static TargetEnvAttributeStorage *
105   construct(AttributeStorageAllocator &allocator, const KeyTy &key) {
106     return new (allocator.allocate<TargetEnvAttributeStorage>())
107         TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key),
108                                   std::get<2>(key), std::get<3>(key),
109                                   std::get<4>(key));
110   }
111 
112   Attribute triple;
113   Attribute limits;
114   Vendor vendorID;
115   DeviceType deviceType;
116   uint32_t deviceID;
117 };
118 } // namespace detail
119 } // namespace spirv
120 } // namespace mlir
121 
122 //===----------------------------------------------------------------------===//
123 // InterfaceVarABIAttr
124 //===----------------------------------------------------------------------===//
125 
126 spirv::InterfaceVarABIAttr
127 spirv::InterfaceVarABIAttr::get(uint32_t descriptorSet, uint32_t binding,
128                                 Optional<spirv::StorageClass> storageClass,
129                                 MLIRContext *context) {
130   Builder b(context);
131   auto descriptorSetAttr = b.getI32IntegerAttr(descriptorSet);
132   auto bindingAttr = b.getI32IntegerAttr(binding);
133   auto storageClassAttr =
134       storageClass ? b.getI32IntegerAttr(static_cast<uint32_t>(*storageClass))
135                    : IntegerAttr();
136   return get(descriptorSetAttr, bindingAttr, storageClassAttr);
137 }
138 
139 spirv::InterfaceVarABIAttr
140 spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding,
141                                 IntegerAttr storageClass) {
142   assert(descriptorSet && binding);
143   MLIRContext *context = descriptorSet.getContext();
144   return Base::get(context, descriptorSet, binding, storageClass);
145 }
146 
147 StringRef spirv::InterfaceVarABIAttr::getKindName() {
148   return "interface_var_abi";
149 }
150 
151 uint32_t spirv::InterfaceVarABIAttr::getBinding() {
152   return getImpl()->binding.cast<IntegerAttr>().getInt();
153 }
154 
155 uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() {
156   return getImpl()->descriptorSet.cast<IntegerAttr>().getInt();
157 }
158 
159 Optional<spirv::StorageClass> spirv::InterfaceVarABIAttr::getStorageClass() {
160   if (getImpl()->storageClass)
161     return static_cast<spirv::StorageClass>(
162         getImpl()->storageClass.cast<IntegerAttr>().getValue().getZExtValue());
163   return llvm::None;
164 }
165 
166 LogicalResult spirv::InterfaceVarABIAttr::verify(
167     function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet,
168     IntegerAttr binding, IntegerAttr storageClass) {
169   if (!descriptorSet.getType().isSignlessInteger(32))
170     return emitError() << "expected 32-bit integer for descriptor set";
171 
172   if (!binding.getType().isSignlessInteger(32))
173     return emitError() << "expected 32-bit integer for binding";
174 
175   if (storageClass) {
176     if (auto storageClassAttr = storageClass.cast<IntegerAttr>()) {
177       auto storageClassValue =
178           spirv::symbolizeStorageClass(storageClassAttr.getInt());
179       if (!storageClassValue)
180         return emitError() << "unknown storage class";
181     } else {
182       return emitError() << "expected valid storage class";
183     }
184   }
185 
186   return success();
187 }
188 
189 //===----------------------------------------------------------------------===//
190 // VerCapExtAttr
191 //===----------------------------------------------------------------------===//
192 
193 spirv::VerCapExtAttr spirv::VerCapExtAttr::get(
194     spirv::Version version, ArrayRef<spirv::Capability> capabilities,
195     ArrayRef<spirv::Extension> extensions, MLIRContext *context) {
196   Builder b(context);
197 
198   auto versionAttr = b.getI32IntegerAttr(static_cast<uint32_t>(version));
199 
200   SmallVector<Attribute, 4> capAttrs;
201   capAttrs.reserve(capabilities.size());
202   for (spirv::Capability cap : capabilities)
203     capAttrs.push_back(b.getI32IntegerAttr(static_cast<uint32_t>(cap)));
204 
205   SmallVector<Attribute, 4> extAttrs;
206   extAttrs.reserve(extensions.size());
207   for (spirv::Extension ext : extensions)
208     extAttrs.push_back(b.getStringAttr(spirv::stringifyExtension(ext)));
209 
210   return get(versionAttr, b.getArrayAttr(capAttrs), b.getArrayAttr(extAttrs));
211 }
212 
213 spirv::VerCapExtAttr spirv::VerCapExtAttr::get(IntegerAttr version,
214                                                ArrayAttr capabilities,
215                                                ArrayAttr extensions) {
216   assert(version && capabilities && extensions);
217   MLIRContext *context = version.getContext();
218   return Base::get(context, version, capabilities, extensions);
219 }
220 
221 StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; }
222 
223 spirv::Version spirv::VerCapExtAttr::getVersion() {
224   return static_cast<spirv::Version>(
225       getImpl()->version.cast<IntegerAttr>().getValue().getZExtValue());
226 }
227 
228 spirv::VerCapExtAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it)
229     : llvm::mapped_iterator<ArrayAttr::iterator,
230                             spirv::Extension (*)(Attribute)>(
231           it, [](Attribute attr) {
232             return *symbolizeExtension(attr.cast<StringAttr>().getValue());
233           }) {}
234 
235 spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() {
236   auto range = getExtensionsAttr().getValue();
237   return {ext_iterator(range.begin()), ext_iterator(range.end())};
238 }
239 
240 ArrayAttr spirv::VerCapExtAttr::getExtensionsAttr() {
241   return getImpl()->extensions.cast<ArrayAttr>();
242 }
243 
244 spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it)
245     : llvm::mapped_iterator<ArrayAttr::iterator,
246                             spirv::Capability (*)(Attribute)>(
247           it, [](Attribute attr) {
248             return *symbolizeCapability(
249                 attr.cast<IntegerAttr>().getValue().getZExtValue());
250           }) {}
251 
252 spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() {
253   auto range = getCapabilitiesAttr().getValue();
254   return {cap_iterator(range.begin()), cap_iterator(range.end())};
255 }
256 
257 ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() {
258   return getImpl()->capabilities.cast<ArrayAttr>();
259 }
260 
261 LogicalResult
262 spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError,
263                              IntegerAttr version, ArrayAttr capabilities,
264                              ArrayAttr extensions) {
265   if (!version.getType().isSignlessInteger(32))
266     return emitError() << "expected 32-bit integer for version";
267 
268   if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) {
269         if (auto intAttr = attr.dyn_cast<IntegerAttr>())
270           if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue()))
271             return true;
272         return false;
273       }))
274     return emitError() << "unknown capability in capability list";
275 
276   if (!llvm::all_of(extensions.getValue(), [](Attribute attr) {
277         if (auto strAttr = attr.dyn_cast<StringAttr>())
278           if (spirv::symbolizeExtension(strAttr.getValue()))
279             return true;
280         return false;
281       }))
282     return emitError() << "unknown extension in extension list";
283 
284   return success();
285 }
286 
287 //===----------------------------------------------------------------------===//
288 // TargetEnvAttr
289 //===----------------------------------------------------------------------===//
290 
291 spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple,
292                                                Vendor vendorID,
293                                                DeviceType deviceType,
294                                                uint32_t deviceID,
295                                                DictionaryAttr limits) {
296   assert(triple && limits && "expected valid triple and limits");
297   MLIRContext *context = triple.getContext();
298   return Base::get(context, triple, vendorID, deviceType, deviceID, limits);
299 }
300 
301 StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; }
302 
303 spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const {
304   return getImpl()->triple.cast<spirv::VerCapExtAttr>();
305 }
306 
307 spirv::Version spirv::TargetEnvAttr::getVersion() const {
308   return getTripleAttr().getVersion();
309 }
310 
311 spirv::VerCapExtAttr::ext_range spirv::TargetEnvAttr::getExtensions() {
312   return getTripleAttr().getExtensions();
313 }
314 
315 ArrayAttr spirv::TargetEnvAttr::getExtensionsAttr() {
316   return getTripleAttr().getExtensionsAttr();
317 }
318 
319 spirv::VerCapExtAttr::cap_range spirv::TargetEnvAttr::getCapabilities() {
320   return getTripleAttr().getCapabilities();
321 }
322 
323 ArrayAttr spirv::TargetEnvAttr::getCapabilitiesAttr() {
324   return getTripleAttr().getCapabilitiesAttr();
325 }
326 
327 spirv::Vendor spirv::TargetEnvAttr::getVendorID() const {
328   return getImpl()->vendorID;
329 }
330 
331 spirv::DeviceType spirv::TargetEnvAttr::getDeviceType() const {
332   return getImpl()->deviceType;
333 }
334 
335 uint32_t spirv::TargetEnvAttr::getDeviceID() const {
336   return getImpl()->deviceID;
337 }
338 
339 spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const {
340   return getImpl()->limits.cast<spirv::ResourceLimitsAttr>();
341 }
342 
343 LogicalResult
344 spirv::TargetEnvAttr::verify(function_ref<InFlightDiagnostic()> emitError,
345                              spirv::VerCapExtAttr /*triple*/,
346                              spirv::Vendor /*vendorID*/,
347                              spirv::DeviceType /*deviceType*/,
348                              uint32_t /*deviceID*/, DictionaryAttr limits) {
349   if (!limits.isa<spirv::ResourceLimitsAttr>())
350     return emitError() << "expected spirv::ResourceLimitsAttr for limits";
351 
352   return success();
353 }
354 
355 //===----------------------------------------------------------------------===//
356 // SPIR-V Dialect
357 //===----------------------------------------------------------------------===//
358 
359 void spirv::SPIRVDialect::registerAttributes() {
360   addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>();
361 }
362