1 //===- TargetAndABI.cpp - SPIR-V target and ABI utilities -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" 10 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 11 #include "mlir/IR/Builders.h" 12 #include "mlir/IR/FunctionSupport.h" 13 #include "mlir/IR/Operation.h" 14 #include "mlir/IR/SymbolTable.h" 15 16 using namespace mlir; 17 18 //===----------------------------------------------------------------------===// 19 // TargetEnv 20 //===----------------------------------------------------------------------===// 21 22 spirv::TargetEnv::TargetEnv(spirv::TargetEnvAttr targetAttr) 23 : targetAttr(targetAttr) { 24 for (spirv::Extension ext : targetAttr.getExtensions()) 25 givenExtensions.insert(ext); 26 27 // Add extensions implied by the current version. 28 for (spirv::Extension ext : 29 spirv::getImpliedExtensions(targetAttr.getVersion())) 30 givenExtensions.insert(ext); 31 32 for (spirv::Capability cap : targetAttr.getCapabilities()) { 33 givenCapabilities.insert(cap); 34 35 // Add capabilities implied by the current capability. 36 for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap)) 37 givenCapabilities.insert(c); 38 } 39 } 40 41 spirv::Version spirv::TargetEnv::getVersion() const { 42 return targetAttr.getVersion(); 43 } 44 45 bool spirv::TargetEnv::allows(spirv::Capability capability) const { 46 return givenCapabilities.count(capability); 47 } 48 49 Optional<spirv::Capability> 50 spirv::TargetEnv::allows(ArrayRef<spirv::Capability> caps) const { 51 const auto *chosen = llvm::find_if(caps, [this](spirv::Capability cap) { 52 return givenCapabilities.count(cap); 53 }); 54 if (chosen != caps.end()) 55 return *chosen; 56 return llvm::None; 57 } 58 59 bool spirv::TargetEnv::allows(spirv::Extension extension) const { 60 return givenExtensions.count(extension); 61 } 62 63 Optional<spirv::Extension> 64 spirv::TargetEnv::allows(ArrayRef<spirv::Extension> exts) const { 65 const auto *chosen = llvm::find_if(exts, [this](spirv::Extension ext) { 66 return givenExtensions.count(ext); 67 }); 68 if (chosen != exts.end()) 69 return *chosen; 70 return llvm::None; 71 } 72 73 spirv::Vendor spirv::TargetEnv::getVendorID() const { 74 return targetAttr.getVendorID(); 75 } 76 77 spirv::DeviceType spirv::TargetEnv::getDeviceType() const { 78 return targetAttr.getDeviceType(); 79 } 80 81 uint32_t spirv::TargetEnv::getDeviceID() const { 82 return targetAttr.getDeviceID(); 83 } 84 85 spirv::ResourceLimitsAttr spirv::TargetEnv::getResourceLimits() const { 86 return targetAttr.getResourceLimits(); 87 } 88 89 MLIRContext *spirv::TargetEnv::getContext() const { 90 return targetAttr.getContext(); 91 } 92 93 //===----------------------------------------------------------------------===// 94 // Utility functions 95 //===----------------------------------------------------------------------===// 96 97 StringRef spirv::getInterfaceVarABIAttrName() { 98 return "spv.interface_var_abi"; 99 } 100 101 spirv::InterfaceVarABIAttr 102 spirv::getInterfaceVarABIAttr(unsigned descriptorSet, unsigned binding, 103 Optional<spirv::StorageClass> storageClass, 104 MLIRContext *context) { 105 return spirv::InterfaceVarABIAttr::get(descriptorSet, binding, storageClass, 106 context); 107 } 108 109 bool spirv::needsInterfaceVarABIAttrs(spirv::TargetEnvAttr targetAttr) { 110 for (spirv::Capability cap : targetAttr.getCapabilities()) { 111 if (cap == spirv::Capability::Kernel) 112 return false; 113 if (cap == spirv::Capability::Shader) 114 return true; 115 } 116 return false; 117 } 118 119 StringRef spirv::getEntryPointABIAttrName() { return "spv.entry_point_abi"; } 120 121 spirv::EntryPointABIAttr 122 spirv::getEntryPointABIAttr(ArrayRef<int32_t> localSize, MLIRContext *context) { 123 assert(localSize.size() == 3); 124 return spirv::EntryPointABIAttr::get( 125 DenseElementsAttr::get<int32_t>( 126 VectorType::get(3, IntegerType::get(context, 32)), localSize) 127 .cast<DenseIntElementsAttr>(), 128 context); 129 } 130 131 spirv::EntryPointABIAttr spirv::lookupEntryPointABI(Operation *op) { 132 while (op && !op->hasTrait<OpTrait::FunctionLike>()) 133 op = op->getParentOp(); 134 if (!op) 135 return {}; 136 137 if (auto attr = op->getAttrOfType<spirv::EntryPointABIAttr>( 138 spirv::getEntryPointABIAttrName())) 139 return attr; 140 141 return {}; 142 } 143 144 DenseIntElementsAttr spirv::lookupLocalWorkGroupSize(Operation *op) { 145 if (auto entryPoint = spirv::lookupEntryPointABI(op)) 146 return entryPoint.local_size(); 147 148 return {}; 149 } 150 151 spirv::ResourceLimitsAttr 152 spirv::getDefaultResourceLimits(MLIRContext *context) { 153 // All the fields have default values. Here we just provide a nicer way to 154 // construct a default resource limit attribute. 155 return spirv::ResourceLimitsAttr ::get( 156 /*max_compute_shared_memory_size=*/nullptr, 157 /*max_compute_workgroup_invocations=*/nullptr, 158 /*max_compute_workgroup_size=*/nullptr, 159 /*subgroup_size=*/nullptr, 160 /*cooperative_matrix_properties_nv=*/nullptr, context); 161 } 162 163 StringRef spirv::getTargetEnvAttrName() { return "spv.target_env"; } 164 165 spirv::TargetEnvAttr spirv::getDefaultTargetEnv(MLIRContext *context) { 166 auto triple = spirv::VerCapExtAttr::get(spirv::Version::V_1_0, 167 {spirv::Capability::Shader}, 168 ArrayRef<Extension>(), context); 169 return spirv::TargetEnvAttr::get(triple, spirv::Vendor::Unknown, 170 spirv::DeviceType::Unknown, 171 spirv::TargetEnvAttr::kUnknownDeviceID, 172 spirv::getDefaultResourceLimits(context)); 173 } 174 175 spirv::TargetEnvAttr spirv::lookupTargetEnv(Operation *op) { 176 while (op) { 177 op = SymbolTable::getNearestSymbolTable(op); 178 if (!op) 179 break; 180 181 if (auto attr = op->getAttrOfType<spirv::TargetEnvAttr>( 182 spirv::getTargetEnvAttrName())) 183 return attr; 184 185 op = op->getParentOp(); 186 } 187 188 return {}; 189 } 190 191 spirv::TargetEnvAttr spirv::lookupTargetEnvOrDefault(Operation *op) { 192 if (spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(op)) 193 return attr; 194 195 return getDefaultTargetEnv(op->getContext()); 196 } 197 198 spirv::AddressingModel 199 spirv::getAddressingModel(spirv::TargetEnvAttr targetAttr) { 200 for (spirv::Capability cap : targetAttr.getCapabilities()) { 201 // TODO: Physical64 is hard-coded here, but some information should come 202 // from TargetEnvAttr to selected between Physical32 and Physical64. 203 if (cap == Capability::Kernel) 204 return spirv::AddressingModel::Physical64; 205 } 206 // Logical addressing doesn't need any capabilities so return it as default. 207 return spirv::AddressingModel::Logical; 208 } 209 210 FailureOr<spirv::ExecutionModel> 211 spirv::getExecutionModel(spirv::TargetEnvAttr targetAttr) { 212 for (spirv::Capability cap : targetAttr.getCapabilities()) { 213 if (cap == spirv::Capability::Kernel) 214 return spirv::ExecutionModel::Kernel; 215 if (cap == spirv::Capability::Shader) 216 return spirv::ExecutionModel::GLCompute; 217 } 218 return failure(); 219 } 220 221 FailureOr<spirv::MemoryModel> 222 spirv::getMemoryModel(spirv::TargetEnvAttr targetAttr) { 223 for (spirv::Capability cap : targetAttr.getCapabilities()) { 224 if (cap == spirv::Capability::Addresses) 225 return spirv::MemoryModel::OpenCL; 226 if (cap == spirv::Capability::Shader) 227 return spirv::MemoryModel::GLSL450; 228 } 229 return failure(); 230 } 231