1 //===- SPIRVAttributes.cpp - SPIR-V attribute definitions -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" 10 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 11 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 12 #include "mlir/IR/Builders.h" 13 14 using namespace mlir; 15 16 //===----------------------------------------------------------------------===// 17 // TableGen'erated attribute utility functions 18 //===----------------------------------------------------------------------===// 19 20 namespace mlir { 21 namespace spirv { 22 #include "mlir/Dialect/SPIRV/IR/SPIRVAttrUtils.inc" 23 } // namespace spirv 24 } // namespace mlir 25 26 //===----------------------------------------------------------------------===// 27 // DictionaryDict derived attributes 28 //===----------------------------------------------------------------------===// 29 30 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.cpp.inc" 31 32 namespace mlir { 33 34 //===----------------------------------------------------------------------===// 35 // Attribute storage classes 36 //===----------------------------------------------------------------------===// 37 38 namespace spirv { 39 namespace detail { 40 41 struct InterfaceVarABIAttributeStorage : public AttributeStorage { 42 using KeyTy = std::tuple<Attribute, Attribute, Attribute>; 43 44 InterfaceVarABIAttributeStorage(Attribute descriptorSet, Attribute binding, 45 Attribute storageClass) 46 : descriptorSet(descriptorSet), binding(binding), 47 storageClass(storageClass) {} 48 49 bool operator==(const KeyTy &key) const { 50 return std::get<0>(key) == descriptorSet && std::get<1>(key) == binding && 51 std::get<2>(key) == storageClass; 52 } 53 54 static InterfaceVarABIAttributeStorage * 55 construct(AttributeStorageAllocator &allocator, const KeyTy &key) { 56 return new (allocator.allocate<InterfaceVarABIAttributeStorage>()) 57 InterfaceVarABIAttributeStorage(std::get<0>(key), std::get<1>(key), 58 std::get<2>(key)); 59 } 60 61 Attribute descriptorSet; 62 Attribute binding; 63 Attribute storageClass; 64 }; 65 66 struct VerCapExtAttributeStorage : public AttributeStorage { 67 using KeyTy = std::tuple<Attribute, Attribute, Attribute>; 68 69 VerCapExtAttributeStorage(Attribute version, Attribute capabilities, 70 Attribute extensions) 71 : version(version), capabilities(capabilities), extensions(extensions) {} 72 73 bool operator==(const KeyTy &key) const { 74 return std::get<0>(key) == version && std::get<1>(key) == capabilities && 75 std::get<2>(key) == extensions; 76 } 77 78 static VerCapExtAttributeStorage * 79 construct(AttributeStorageAllocator &allocator, const KeyTy &key) { 80 return new (allocator.allocate<VerCapExtAttributeStorage>()) 81 VerCapExtAttributeStorage(std::get<0>(key), std::get<1>(key), 82 std::get<2>(key)); 83 } 84 85 Attribute version; 86 Attribute capabilities; 87 Attribute extensions; 88 }; 89 90 struct TargetEnvAttributeStorage : public AttributeStorage { 91 using KeyTy = std::tuple<Attribute, Vendor, DeviceType, uint32_t, Attribute>; 92 93 TargetEnvAttributeStorage(Attribute triple, Vendor vendorID, 94 DeviceType deviceType, uint32_t deviceID, 95 Attribute limits) 96 : triple(triple), limits(limits), vendorID(vendorID), 97 deviceType(deviceType), deviceID(deviceID) {} 98 99 bool operator==(const KeyTy &key) const { 100 return key == 101 std::make_tuple(triple, vendorID, deviceType, deviceID, limits); 102 } 103 104 static TargetEnvAttributeStorage * 105 construct(AttributeStorageAllocator &allocator, const KeyTy &key) { 106 return new (allocator.allocate<TargetEnvAttributeStorage>()) 107 TargetEnvAttributeStorage(std::get<0>(key), std::get<1>(key), 108 std::get<2>(key), std::get<3>(key), 109 std::get<4>(key)); 110 } 111 112 Attribute triple; 113 Attribute limits; 114 Vendor vendorID; 115 DeviceType deviceType; 116 uint32_t deviceID; 117 }; 118 } // namespace detail 119 } // namespace spirv 120 } // namespace mlir 121 122 //===----------------------------------------------------------------------===// 123 // InterfaceVarABIAttr 124 //===----------------------------------------------------------------------===// 125 126 spirv::InterfaceVarABIAttr 127 spirv::InterfaceVarABIAttr::get(uint32_t descriptorSet, uint32_t binding, 128 Optional<spirv::StorageClass> storageClass, 129 MLIRContext *context) { 130 Builder b(context); 131 auto descriptorSetAttr = b.getI32IntegerAttr(descriptorSet); 132 auto bindingAttr = b.getI32IntegerAttr(binding); 133 auto storageClassAttr = 134 storageClass ? b.getI32IntegerAttr(static_cast<uint32_t>(*storageClass)) 135 : IntegerAttr(); 136 return get(descriptorSetAttr, bindingAttr, storageClassAttr); 137 } 138 139 spirv::InterfaceVarABIAttr 140 spirv::InterfaceVarABIAttr::get(IntegerAttr descriptorSet, IntegerAttr binding, 141 IntegerAttr storageClass) { 142 assert(descriptorSet && binding); 143 MLIRContext *context = descriptorSet.getContext(); 144 return Base::get(context, descriptorSet, binding, storageClass); 145 } 146 147 StringRef spirv::InterfaceVarABIAttr::getKindName() { 148 return "interface_var_abi"; 149 } 150 151 uint32_t spirv::InterfaceVarABIAttr::getBinding() { 152 return getImpl()->binding.cast<IntegerAttr>().getInt(); 153 } 154 155 uint32_t spirv::InterfaceVarABIAttr::getDescriptorSet() { 156 return getImpl()->descriptorSet.cast<IntegerAttr>().getInt(); 157 } 158 159 Optional<spirv::StorageClass> spirv::InterfaceVarABIAttr::getStorageClass() { 160 if (getImpl()->storageClass) 161 return static_cast<spirv::StorageClass>( 162 getImpl()->storageClass.cast<IntegerAttr>().getValue().getZExtValue()); 163 return llvm::None; 164 } 165 166 LogicalResult spirv::InterfaceVarABIAttr::verify( 167 function_ref<InFlightDiagnostic()> emitError, IntegerAttr descriptorSet, 168 IntegerAttr binding, IntegerAttr storageClass) { 169 if (!descriptorSet.getType().isSignlessInteger(32)) 170 return emitError() << "expected 32-bit integer for descriptor set"; 171 172 if (!binding.getType().isSignlessInteger(32)) 173 return emitError() << "expected 32-bit integer for binding"; 174 175 if (storageClass) { 176 if (auto storageClassAttr = storageClass.cast<IntegerAttr>()) { 177 auto storageClassValue = 178 spirv::symbolizeStorageClass(storageClassAttr.getInt()); 179 if (!storageClassValue) 180 return emitError() << "unknown storage class"; 181 } else { 182 return emitError() << "expected valid storage class"; 183 } 184 } 185 186 return success(); 187 } 188 189 //===----------------------------------------------------------------------===// 190 // VerCapExtAttr 191 //===----------------------------------------------------------------------===// 192 193 spirv::VerCapExtAttr spirv::VerCapExtAttr::get( 194 spirv::Version version, ArrayRef<spirv::Capability> capabilities, 195 ArrayRef<spirv::Extension> extensions, MLIRContext *context) { 196 Builder b(context); 197 198 auto versionAttr = b.getI32IntegerAttr(static_cast<uint32_t>(version)); 199 200 SmallVector<Attribute, 4> capAttrs; 201 capAttrs.reserve(capabilities.size()); 202 for (spirv::Capability cap : capabilities) 203 capAttrs.push_back(b.getI32IntegerAttr(static_cast<uint32_t>(cap))); 204 205 SmallVector<Attribute, 4> extAttrs; 206 extAttrs.reserve(extensions.size()); 207 for (spirv::Extension ext : extensions) 208 extAttrs.push_back(b.getStringAttr(spirv::stringifyExtension(ext))); 209 210 return get(versionAttr, b.getArrayAttr(capAttrs), b.getArrayAttr(extAttrs)); 211 } 212 213 spirv::VerCapExtAttr spirv::VerCapExtAttr::get(IntegerAttr version, 214 ArrayAttr capabilities, 215 ArrayAttr extensions) { 216 assert(version && capabilities && extensions); 217 MLIRContext *context = version.getContext(); 218 return Base::get(context, version, capabilities, extensions); 219 } 220 221 StringRef spirv::VerCapExtAttr::getKindName() { return "vce"; } 222 223 spirv::Version spirv::VerCapExtAttr::getVersion() { 224 return static_cast<spirv::Version>( 225 getImpl()->version.cast<IntegerAttr>().getValue().getZExtValue()); 226 } 227 228 spirv::VerCapExtAttr::ext_iterator::ext_iterator(ArrayAttr::iterator it) 229 : llvm::mapped_iterator<ArrayAttr::iterator, 230 spirv::Extension (*)(Attribute)>( 231 it, [](Attribute attr) { 232 return *symbolizeExtension(attr.cast<StringAttr>().getValue()); 233 }) {} 234 235 spirv::VerCapExtAttr::ext_range spirv::VerCapExtAttr::getExtensions() { 236 auto range = getExtensionsAttr().getValue(); 237 return {ext_iterator(range.begin()), ext_iterator(range.end())}; 238 } 239 240 ArrayAttr spirv::VerCapExtAttr::getExtensionsAttr() { 241 return getImpl()->extensions.cast<ArrayAttr>(); 242 } 243 244 spirv::VerCapExtAttr::cap_iterator::cap_iterator(ArrayAttr::iterator it) 245 : llvm::mapped_iterator<ArrayAttr::iterator, 246 spirv::Capability (*)(Attribute)>( 247 it, [](Attribute attr) { 248 return *symbolizeCapability( 249 attr.cast<IntegerAttr>().getValue().getZExtValue()); 250 }) {} 251 252 spirv::VerCapExtAttr::cap_range spirv::VerCapExtAttr::getCapabilities() { 253 auto range = getCapabilitiesAttr().getValue(); 254 return {cap_iterator(range.begin()), cap_iterator(range.end())}; 255 } 256 257 ArrayAttr spirv::VerCapExtAttr::getCapabilitiesAttr() { 258 return getImpl()->capabilities.cast<ArrayAttr>(); 259 } 260 261 LogicalResult 262 spirv::VerCapExtAttr::verify(function_ref<InFlightDiagnostic()> emitError, 263 IntegerAttr version, ArrayAttr capabilities, 264 ArrayAttr extensions) { 265 if (!version.getType().isSignlessInteger(32)) 266 return emitError() << "expected 32-bit integer for version"; 267 268 if (!llvm::all_of(capabilities.getValue(), [](Attribute attr) { 269 if (auto intAttr = attr.dyn_cast<IntegerAttr>()) 270 if (spirv::symbolizeCapability(intAttr.getValue().getZExtValue())) 271 return true; 272 return false; 273 })) 274 return emitError() << "unknown capability in capability list"; 275 276 if (!llvm::all_of(extensions.getValue(), [](Attribute attr) { 277 if (auto strAttr = attr.dyn_cast<StringAttr>()) 278 if (spirv::symbolizeExtension(strAttr.getValue())) 279 return true; 280 return false; 281 })) 282 return emitError() << "unknown extension in extension list"; 283 284 return success(); 285 } 286 287 //===----------------------------------------------------------------------===// 288 // TargetEnvAttr 289 //===----------------------------------------------------------------------===// 290 291 spirv::TargetEnvAttr spirv::TargetEnvAttr::get(spirv::VerCapExtAttr triple, 292 Vendor vendorID, 293 DeviceType deviceType, 294 uint32_t deviceID, 295 DictionaryAttr limits) { 296 assert(triple && limits && "expected valid triple and limits"); 297 MLIRContext *context = triple.getContext(); 298 return Base::get(context, triple, vendorID, deviceType, deviceID, limits); 299 } 300 301 StringRef spirv::TargetEnvAttr::getKindName() { return "target_env"; } 302 303 spirv::VerCapExtAttr spirv::TargetEnvAttr::getTripleAttr() const { 304 return getImpl()->triple.cast<spirv::VerCapExtAttr>(); 305 } 306 307 spirv::Version spirv::TargetEnvAttr::getVersion() const { 308 return getTripleAttr().getVersion(); 309 } 310 311 spirv::VerCapExtAttr::ext_range spirv::TargetEnvAttr::getExtensions() { 312 return getTripleAttr().getExtensions(); 313 } 314 315 ArrayAttr spirv::TargetEnvAttr::getExtensionsAttr() { 316 return getTripleAttr().getExtensionsAttr(); 317 } 318 319 spirv::VerCapExtAttr::cap_range spirv::TargetEnvAttr::getCapabilities() { 320 return getTripleAttr().getCapabilities(); 321 } 322 323 ArrayAttr spirv::TargetEnvAttr::getCapabilitiesAttr() { 324 return getTripleAttr().getCapabilitiesAttr(); 325 } 326 327 spirv::Vendor spirv::TargetEnvAttr::getVendorID() const { 328 return getImpl()->vendorID; 329 } 330 331 spirv::DeviceType spirv::TargetEnvAttr::getDeviceType() const { 332 return getImpl()->deviceType; 333 } 334 335 uint32_t spirv::TargetEnvAttr::getDeviceID() const { 336 return getImpl()->deviceID; 337 } 338 339 spirv::ResourceLimitsAttr spirv::TargetEnvAttr::getResourceLimits() const { 340 return getImpl()->limits.cast<spirv::ResourceLimitsAttr>(); 341 } 342 343 LogicalResult 344 spirv::TargetEnvAttr::verify(function_ref<InFlightDiagnostic()> emitError, 345 spirv::VerCapExtAttr /*triple*/, 346 spirv::Vendor /*vendorID*/, 347 spirv::DeviceType /*deviceType*/, 348 uint32_t /*deviceID*/, DictionaryAttr limits) { 349 if (!limits.isa<spirv::ResourceLimitsAttr>()) 350 return emitError() << "expected spirv::ResourceLimitsAttr for limits"; 351 352 return success(); 353 } 354 355 //===----------------------------------------------------------------------===// 356 // SPIR-V Dialect 357 //===----------------------------------------------------------------------===// 358 359 void spirv::SPIRVDialect::registerAttributes() { 360 addAttributes<InterfaceVarABIAttr, TargetEnvAttr, VerCapExtAttr>(); 361 } 362