1 //===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements a pass to deduce minimal version/extension/capability 10 // requirements for a spirv::ModuleOp. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "PassDetail.h" 15 #include "mlir/Dialect/SPIRV/Passes.h" 16 #include "mlir/Dialect/SPIRV/SPIRVDialect.h" 17 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 18 #include "mlir/Dialect/SPIRV/SPIRVTypes.h" 19 #include "mlir/Dialect/SPIRV/TargetAndABI.h" 20 #include "mlir/IR/Builders.h" 21 #include "mlir/IR/Visitors.h" 22 #include "llvm/ADT/SetVector.h" 23 #include "llvm/ADT/SmallSet.h" 24 25 using namespace mlir; 26 27 namespace { 28 /// Pass to deduce minimal version/extension/capability requirements for a 29 /// spirv::ModuleOp. 30 class UpdateVCEPass final : public SPIRVUpdateVCEBase<UpdateVCEPass> { 31 void runOnOperation() override; 32 }; 33 } // namespace 34 35 /// Checks that `candidates` extension requirements are possible to be satisfied 36 /// with the given `targetEnv` and updates `deducedExtensions` if so. Emits 37 /// errors attaching to the given `op` on failures. 38 /// 39 /// `candidates` is a vector of vector for extension requirements following 40 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D)) 41 /// convention. 42 static LogicalResult checkAndUpdateExtensionRequirements( 43 Operation *op, const spirv::TargetEnv &targetEnv, 44 const spirv::SPIRVType::ExtensionArrayRefVector &candidates, 45 llvm::SetVector<spirv::Extension> &deducedExtensions) { 46 for (const auto &ors : candidates) { 47 if (Optional<spirv::Extension> chosen = targetEnv.allows(ors)) { 48 deducedExtensions.insert(*chosen); 49 } else { 50 SmallVector<StringRef, 4> extStrings; 51 for (spirv::Extension ext : ors) 52 extStrings.push_back(spirv::stringifyExtension(ext)); 53 54 return op->emitError("'") 55 << op->getName() << "' requires at least one extension in [" 56 << llvm::join(extStrings, ", ") 57 << "] but none allowed in target environment"; 58 } 59 } 60 return success(); 61 } 62 63 /// Checks that `candidates`capability requirements are possible to be satisfied 64 /// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits 65 /// errors attaching to the given `op` on failures. 66 /// 67 /// `candidates` is a vector of vector for capability requirements following 68 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) 69 /// convention. 70 static LogicalResult checkAndUpdateCapabilityRequirements( 71 Operation *op, const spirv::TargetEnv &targetEnv, 72 const spirv::SPIRVType::CapabilityArrayRefVector &candidates, 73 llvm::SetVector<spirv::Capability> &deducedCapabilities) { 74 for (const auto &ors : candidates) { 75 if (Optional<spirv::Capability> chosen = targetEnv.allows(ors)) { 76 deducedCapabilities.insert(*chosen); 77 } else { 78 SmallVector<StringRef, 4> capStrings; 79 for (spirv::Capability cap : ors) 80 capStrings.push_back(spirv::stringifyCapability(cap)); 81 82 return op->emitError("'") 83 << op->getName() << "' requires at least one capability in [" 84 << llvm::join(capStrings, ", ") 85 << "] but none allowed in target environment"; 86 } 87 } 88 return success(); 89 } 90 91 void UpdateVCEPass::runOnOperation() { 92 spirv::ModuleOp module = getOperation(); 93 94 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module); 95 if (!targetAttr) { 96 module.emitError("missing 'spv.target_env' attribute"); 97 return signalPassFailure(); 98 } 99 100 spirv::TargetEnv targetEnv(targetAttr); 101 spirv::Version allowedVersion = targetAttr.getVersion(); 102 103 spirv::Version deducedVersion = spirv::Version::V_1_0; 104 llvm::SetVector<spirv::Extension> deducedExtensions; 105 llvm::SetVector<spirv::Capability> deducedCapabilities; 106 107 // Walk each SPIR-V op to deduce the minimal version/extension/capability 108 // requirements. 109 WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult { 110 // Op min version requirements 111 if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) { 112 deducedVersion = std::max(deducedVersion, minVersion.getMinVersion()); 113 if (deducedVersion > allowedVersion) { 114 return op->emitError("'") << op->getName() << "' requires min version " 115 << spirv::stringifyVersion(deducedVersion) 116 << " but target environment allows up to " 117 << spirv::stringifyVersion(allowedVersion); 118 } 119 } 120 121 // Op extension requirements 122 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) 123 if (failed(checkAndUpdateExtensionRequirements( 124 op, targetEnv, extensions.getExtensions(), deducedExtensions))) 125 return WalkResult::interrupt(); 126 127 // Op capability requirements 128 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) 129 if (failed(checkAndUpdateCapabilityRequirements( 130 op, targetEnv, capabilities.getCapabilities(), 131 deducedCapabilities))) 132 return WalkResult::interrupt(); 133 134 SmallVector<Type, 4> valueTypes; 135 valueTypes.append(op->operand_type_begin(), op->operand_type_end()); 136 valueTypes.append(op->result_type_begin(), op->result_type_end()); 137 138 // Special treatment for global variables, whose type requirements are 139 // conveyed by type attributes. 140 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) 141 valueTypes.push_back(globalVar.type()); 142 143 // Requirements from values' types 144 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; 145 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; 146 for (Type valueType : valueTypes) { 147 typeExtensions.clear(); 148 valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions); 149 if (failed(checkAndUpdateExtensionRequirements( 150 op, targetEnv, typeExtensions, deducedExtensions))) 151 return WalkResult::interrupt(); 152 153 typeCapabilities.clear(); 154 valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities); 155 if (failed(checkAndUpdateCapabilityRequirements( 156 op, targetEnv, typeCapabilities, deducedCapabilities))) 157 return WalkResult::interrupt(); 158 } 159 160 return WalkResult::advance(); 161 }); 162 163 if (walkResult.wasInterrupted()) 164 return signalPassFailure(); 165 166 // TODO: verify that the deduced version is consistent with 167 // SPIR-V ops' maximal version requirements. 168 169 auto triple = spirv::VerCapExtAttr::get( 170 deducedVersion, deducedCapabilities.getArrayRef(), 171 deducedExtensions.getArrayRef(), &getContext()); 172 module.setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple); 173 } 174 175 std::unique_ptr<OperationPass<spirv::ModuleOp>> 176 mlir::spirv::createUpdateVersionCapabilityExtensionPass() { 177 return std::make_unique<UpdateVCEPass>(); 178 } 179