1 //===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to deduce minimal version/extension/capability 10 // requirements for a spirv::ModuleOp. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Dialect/SPIRV/Passes.h" 15 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/SPIRVTypes.h" 18 #include "mlir/Dialect/SPIRV/TargetAndABI.h" 19 #include "mlir/IR/Builders.h" 20 #include "mlir/IR/Visitors.h" 21 #include "llvm/ADT/SetVector.h" 22 #include "llvm/ADT/SmallSet.h" 23 24 using namespace mlir; 25 26 namespace { 27 /// Pass to deduce minimal version/extension/capability requirements for a 28 /// spirv::ModuleOp. 29 class UpdateVCEPass final 30 : public OperationPass<UpdateVCEPass, spirv::ModuleOp> { 31 private: 32 void runOnOperation() override; 33 }; 34 } // namespace 35 36 /// Checks that `candidates` extension requirements are possible to be satisfied 37 /// with the given `allowedExtensions` and updates `deducedExtensions` if so. 38 /// Emits errors attaching to the given `op` on failures. 39 /// 40 /// `candidates` is a vector of vector for extension requirements following 41 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) 42 /// convention. 43 static LogicalResult checkAndUpdateExtensionRequirements( 44 Operation *op, const llvm::SmallSet<spirv::Extension, 4> &allowedExtensions, 45 const spirv::SPIRVType::ExtensionArrayRefVector &candidates, 46 llvm::SetVector<spirv::Extension> &deducedExtensions) { 47 for (const auto &ors : candidates) { 48 auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) { 49 return allowedExtensions.count(ext); 50 }); 51 52 if (chosen != ors.end()) { 53 deducedExtensions.insert(*chosen); 54 } else { 55 SmallVector<StringRef, 4> extStrings; 56 for (spirv::Extension ext : ors) 57 extStrings.push_back(spirv::stringifyExtension(ext)); 58 59 return op->emitError("'") 60 << op->getName() << "' requires at least one extension in [" 61 << llvm::join(extStrings, ", ") 62 << "] but none allowed in target environment"; 63 } 64 } 65 return success(); 66 } 67 68 /// Checks that `candidates`capability requirements are possible to be satisfied 69 /// with the given `allowedCapabilities` and updates `deducedCapabilities` if 70 /// so. Emits errors attaching to the given `op` on failures. 71 /// 72 /// `candidates` is a vector of vector for capability requirements following 73 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) 74 /// convention. 75 static LogicalResult checkAndUpdateCapabilityRequirements( 76 Operation *op, 77 const llvm::SmallSet<spirv::Capability, 8> &allowedCapabilities, 78 const spirv::SPIRVType::CapabilityArrayRefVector &candidates, 79 llvm::SetVector<spirv::Capability> &deducedCapabilities) { 80 for (const auto &ors : candidates) { 81 auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) { 82 return allowedCapabilities.count(cap); 83 }); 84 85 if (chosen != ors.end()) { 86 deducedCapabilities.insert(*chosen); 87 } else { 88 SmallVector<StringRef, 4> capStrings; 89 for (spirv::Capability cap : ors) 90 capStrings.push_back(spirv::stringifyCapability(cap)); 91 92 return op->emitError("'") 93 << op->getName() << "' requires at least one capability in [" 94 << llvm::join(capStrings, ", ") 95 << "] but none allowed in target environment"; 96 } 97 } 98 return success(); 99 } 100 101 void UpdateVCEPass::runOnOperation() { 102 spirv::ModuleOp module = getOperation(); 103 104 spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(module); 105 if (!targetEnv) { 106 module.emitError("missing 'spv.target_env' attribute"); 107 return signalPassFailure(); 108 } 109 110 spirv::Version allowedVersion = targetEnv.getVersion(); 111 112 // Build a set for available extensions in the target environment. 113 llvm::SmallSet<spirv::Extension, 4> allowedExtensions; 114 for (spirv::Extension ext : targetEnv.getExtensions()) 115 allowedExtensions.insert(ext); 116 117 // Add extensions implied by the current version. 118 for (spirv::Extension ext : spirv::getImpliedExtensions(allowedVersion)) 119 allowedExtensions.insert(ext); 120 121 // Build a set for available capabilities in the target environment. 122 llvm::SmallSet<spirv::Capability, 8> allowedCapabilities; 123 for (spirv::Capability cap : targetEnv.getCapabilities()) { 124 allowedCapabilities.insert(cap); 125 126 // Add capabilities implied by the current capability. 127 for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap)) 128 allowedCapabilities.insert(c); 129 } 130 131 spirv::Version deducedVersion = spirv::Version::V_1_0; 132 llvm::SetVector<spirv::Extension> deducedExtensions; 133 llvm::SetVector<spirv::Capability> deducedCapabilities; 134 135 // Walk each SPIR-V op to deduce the minimal version/extension/capability 136 // requirements. 137 WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult { 138 // Op min version requirements 139 if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) { 140 deducedVersion = std::max(deducedVersion, minVersion.getMinVersion()); 141 if (deducedVersion > allowedVersion) { 142 return op->emitError("'") << op->getName() << "' requires min version " 143 << spirv::stringifyVersion(deducedVersion) 144 << " but target environment allows up to " 145 << spirv::stringifyVersion(allowedVersion); 146 } 147 } 148 149 // Op extension requirements 150 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) 151 if (failed(checkAndUpdateExtensionRequirements(op, allowedExtensions, 152 extensions.getExtensions(), 153 deducedExtensions))) 154 return WalkResult::interrupt(); 155 156 // Op capability requirements 157 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) 158 if (failed(checkAndUpdateCapabilityRequirements( 159 op, allowedCapabilities, capabilities.getCapabilities(), 160 deducedCapabilities))) 161 return WalkResult::interrupt(); 162 163 SmallVector<Type, 4> valueTypes; 164 valueTypes.append(op->operand_type_begin(), op->operand_type_end()); 165 valueTypes.append(op->result_type_begin(), op->result_type_end()); 166 167 // Special treatment for global variables, whose type requirements are 168 // conveyed by type attributes. 169 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) 170 valueTypes.push_back(globalVar.type()); 171 172 // Requirements from values' types 173 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; 174 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; 175 for (Type valueType : valueTypes) { 176 typeExtensions.clear(); 177 valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions); 178 if (failed(checkAndUpdateExtensionRequirements( 179 op, allowedExtensions, typeExtensions, deducedExtensions))) 180 return WalkResult::interrupt(); 181 182 typeCapabilities.clear(); 183 valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities); 184 if (failed(checkAndUpdateCapabilityRequirements( 185 op, allowedCapabilities, typeCapabilities, deducedCapabilities))) 186 return WalkResult::interrupt(); 187 } 188 189 return WalkResult::advance(); 190 }); 191 192 if (walkResult.wasInterrupted()) 193 return signalPassFailure(); 194 195 // TODO(antiagainst): verify that the deduced version is consistent with 196 // SPIR-V ops' maximal version requirements. 197 198 auto triple = spirv::VerCapExtAttr::get( 199 deducedVersion, deducedCapabilities.getArrayRef(), 200 deducedExtensions.getArrayRef(), &getContext()); 201 module.setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple); 202 } 203 204 std::unique_ptr<OpPassBase<spirv::ModuleOp>> 205 mlir::spirv::createUpdateVersionCapabilityExtensionPass() { 206 return std::make_unique<UpdateVCEPass>(); 207 } 208 209 static PassRegistration<UpdateVCEPass> 210 pass("spirv-update-vce", 211 "Deduce and attach minimal (version, capabilities, extensions) " 212 "requirements to spv.module ops"); 213