//===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass to deduce minimal version/extension/capability
// requirements for a spirv::ModuleOp.
//
//===----------------------------------------------------------------------===//

#include "PassDetail.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
#include "mlir/Dialect/SPIRV/Transforms/Passes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Visitors.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringExtras.h"

using namespace mlir;

namespace {
/// Pass that deduces the minimal version/extension/capability requirements
/// for a spirv::ModuleOp.
///
/// The class itself only declares the `runOnOperation()` override; it is kept
/// in an anonymous namespace so it stays local to this file.
class UpdateVCEPass final : public SPIRVUpdateVCEBase<UpdateVCEPass> {
  void runOnOperation() override;
};
} // namespace

/// Checks that the extension requirements in `candidates` can be satisfied by
/// the given `targetEnv` and, if so, records the chosen extensions in
/// `deducedExtensions`. Emits an error attached to the given `op` on failure.
///
/// `candidates` is a vector of vectors of extension requirements following the
/// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
/// convention: every inner vector must be satisfied, and any one of its
/// entries satisfies it.
43 static LogicalResult checkAndUpdateExtensionRequirements( 44 Operation *op, const spirv::TargetEnv &targetEnv, 45 const spirv::SPIRVType::ExtensionArrayRefVector &candidates, 46 SetVector<spirv::Extension> &deducedExtensions) { 47 for (const auto &ors : candidates) { 48 if (Optional<spirv::Extension> chosen = targetEnv.allows(ors)) { 49 deducedExtensions.insert(*chosen); 50 } else { 51 SmallVector<StringRef, 4> extStrings; 52 for (spirv::Extension ext : ors) 53 extStrings.push_back(spirv::stringifyExtension(ext)); 54 55 return op->emitError("'") 56 << op->getName() << "' requires at least one extension in [" 57 << llvm::join(extStrings, ", ") 58 << "] but none allowed in target environment"; 59 } 60 } 61 return success(); 62 } 63 64 /// Checks that `candidates`capability requirements are possible to be satisfied 65 /// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits 66 /// errors attaching to the given `op` on failures. 67 /// 68 /// `candidates` is a vector of vector for capability requirements following 69 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D)) 70 /// convention. 
71 static LogicalResult checkAndUpdateCapabilityRequirements( 72 Operation *op, const spirv::TargetEnv &targetEnv, 73 const spirv::SPIRVType::CapabilityArrayRefVector &candidates, 74 SetVector<spirv::Capability> &deducedCapabilities) { 75 for (const auto &ors : candidates) { 76 if (Optional<spirv::Capability> chosen = targetEnv.allows(ors)) { 77 deducedCapabilities.insert(*chosen); 78 } else { 79 SmallVector<StringRef, 4> capStrings; 80 for (spirv::Capability cap : ors) 81 capStrings.push_back(spirv::stringifyCapability(cap)); 82 83 return op->emitError("'") 84 << op->getName() << "' requires at least one capability in [" 85 << llvm::join(capStrings, ", ") 86 << "] but none allowed in target environment"; 87 } 88 } 89 return success(); 90 } 91 92 void UpdateVCEPass::runOnOperation() { 93 spirv::ModuleOp module = getOperation(); 94 95 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module); 96 if (!targetAttr) { 97 module.emitError("missing 'spv.target_env' attribute"); 98 return signalPassFailure(); 99 } 100 101 spirv::TargetEnv targetEnv(targetAttr); 102 spirv::Version allowedVersion = targetAttr.getVersion(); 103 104 spirv::Version deducedVersion = spirv::Version::V_1_0; 105 SetVector<spirv::Extension> deducedExtensions; 106 SetVector<spirv::Capability> deducedCapabilities; 107 108 // Walk each SPIR-V op to deduce the minimal version/extension/capability 109 // requirements. 
110 WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult { 111 // Op min version requirements 112 if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) { 113 deducedVersion = std::max(deducedVersion, minVersion.getMinVersion()); 114 if (deducedVersion > allowedVersion) { 115 return op->emitError("'") << op->getName() << "' requires min version " 116 << spirv::stringifyVersion(deducedVersion) 117 << " but target environment allows up to " 118 << spirv::stringifyVersion(allowedVersion); 119 } 120 } 121 122 // Op extension requirements 123 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op)) 124 if (failed(checkAndUpdateExtensionRequirements( 125 op, targetEnv, extensions.getExtensions(), deducedExtensions))) 126 return WalkResult::interrupt(); 127 128 // Op capability requirements 129 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op)) 130 if (failed(checkAndUpdateCapabilityRequirements( 131 op, targetEnv, capabilities.getCapabilities(), 132 deducedCapabilities))) 133 return WalkResult::interrupt(); 134 135 SmallVector<Type, 4> valueTypes; 136 valueTypes.append(op->operand_type_begin(), op->operand_type_end()); 137 valueTypes.append(op->result_type_begin(), op->result_type_end()); 138 139 // Special treatment for global variables, whose type requirements are 140 // conveyed by type attributes. 
141 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op)) 142 valueTypes.push_back(globalVar.type()); 143 144 // Requirements from values' types 145 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions; 146 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities; 147 for (Type valueType : valueTypes) { 148 typeExtensions.clear(); 149 valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions); 150 if (failed(checkAndUpdateExtensionRequirements( 151 op, targetEnv, typeExtensions, deducedExtensions))) 152 return WalkResult::interrupt(); 153 154 typeCapabilities.clear(); 155 valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities); 156 if (failed(checkAndUpdateCapabilityRequirements( 157 op, targetEnv, typeCapabilities, deducedCapabilities))) 158 return WalkResult::interrupt(); 159 } 160 161 return WalkResult::advance(); 162 }); 163 164 if (walkResult.wasInterrupted()) 165 return signalPassFailure(); 166 167 // TODO: verify that the deduced version is consistent with 168 // SPIR-V ops' maximal version requirements. 169 170 auto triple = spirv::VerCapExtAttr::get( 171 deducedVersion, deducedCapabilities.getArrayRef(), 172 deducedExtensions.getArrayRef(), &getContext()); 173 module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple); 174 } 175 176 std::unique_ptr<OperationPass<spirv::ModuleOp>> 177 mlir::spirv::createUpdateVersionCapabilityExtensionPass() { 178 return std::make_unique<UpdateVCEPass>(); 179 } 180