1 //===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to deduce minimal version/extension/capability
10 // requirements for a spirv::ModuleOp.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "PassDetail.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
19 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Visitors.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/StringExtras.h"
25
26 using namespace mlir;
27
28 namespace {
29 /// Pass to deduce minimal version/extension/capability requirements for a
30 /// spirv::ModuleOp.
31 class UpdateVCEPass final : public SPIRVUpdateVCEBase<UpdateVCEPass> {
32 void runOnOperation() override;
33 };
34 } // namespace
35
36 /// Checks that `candidates` extension requirements are possible to be satisfied
37 /// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
38 /// errors attaching to the given `op` on failures.
39 ///
40 /// `candidates` is a vector of vector for extension requirements following
41 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
42 /// convention.
checkAndUpdateExtensionRequirements(Operation * op,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::ExtensionArrayRefVector & candidates,SetVector<spirv::Extension> & deducedExtensions)43 static LogicalResult checkAndUpdateExtensionRequirements(
44 Operation *op, const spirv::TargetEnv &targetEnv,
45 const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
46 SetVector<spirv::Extension> &deducedExtensions) {
47 for (const auto &ors : candidates) {
48 if (Optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
49 deducedExtensions.insert(*chosen);
50 } else {
51 SmallVector<StringRef, 4> extStrings;
52 for (spirv::Extension ext : ors)
53 extStrings.push_back(spirv::stringifyExtension(ext));
54
55 return op->emitError("'")
56 << op->getName() << "' requires at least one extension in ["
57 << llvm::join(extStrings, ", ")
58 << "] but none allowed in target environment";
59 }
60 }
61 return success();
62 }
63
64 /// Checks that `candidates`capability requirements are possible to be satisfied
65 /// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
66 /// errors attaching to the given `op` on failures.
67 ///
68 /// `candidates` is a vector of vector for capability requirements following
69 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
70 /// convention.
checkAndUpdateCapabilityRequirements(Operation * op,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::CapabilityArrayRefVector & candidates,SetVector<spirv::Capability> & deducedCapabilities)71 static LogicalResult checkAndUpdateCapabilityRequirements(
72 Operation *op, const spirv::TargetEnv &targetEnv,
73 const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
74 SetVector<spirv::Capability> &deducedCapabilities) {
75 for (const auto &ors : candidates) {
76 if (Optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
77 deducedCapabilities.insert(*chosen);
78 } else {
79 SmallVector<StringRef, 4> capStrings;
80 for (spirv::Capability cap : ors)
81 capStrings.push_back(spirv::stringifyCapability(cap));
82
83 return op->emitError("'")
84 << op->getName() << "' requires at least one capability in ["
85 << llvm::join(capStrings, ", ")
86 << "] but none allowed in target environment";
87 }
88 }
89 return success();
90 }
91
runOnOperation()92 void UpdateVCEPass::runOnOperation() {
93 spirv::ModuleOp module = getOperation();
94
95 spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
96 if (!targetAttr) {
97 module.emitError("missing 'spv.target_env' attribute");
98 return signalPassFailure();
99 }
100
101 spirv::TargetEnv targetEnv(targetAttr);
102 spirv::Version allowedVersion = targetAttr.getVersion();
103
104 spirv::Version deducedVersion = spirv::Version::V_1_0;
105 SetVector<spirv::Extension> deducedExtensions;
106 SetVector<spirv::Capability> deducedCapabilities;
107
108 // Walk each SPIR-V op to deduce the minimal version/extension/capability
109 // requirements.
110 WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
111 // Op min version requirements
112 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
113 Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
114 if (minVersion) {
115 deducedVersion = std::max(deducedVersion, *minVersion);
116 if (deducedVersion > allowedVersion) {
117 return op->emitError("'")
118 << op->getName() << "' requires min version "
119 << spirv::stringifyVersion(deducedVersion)
120 << " but target environment allows up to "
121 << spirv::stringifyVersion(allowedVersion);
122 }
123 }
124 }
125
126 // Op extension requirements
127 if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
128 if (failed(checkAndUpdateExtensionRequirements(
129 op, targetEnv, extensions.getExtensions(), deducedExtensions)))
130 return WalkResult::interrupt();
131
132 // Op capability requirements
133 if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
134 if (failed(checkAndUpdateCapabilityRequirements(
135 op, targetEnv, capabilities.getCapabilities(),
136 deducedCapabilities)))
137 return WalkResult::interrupt();
138
139 SmallVector<Type, 4> valueTypes;
140 valueTypes.append(op->operand_type_begin(), op->operand_type_end());
141 valueTypes.append(op->result_type_begin(), op->result_type_end());
142
143 // Special treatment for global variables, whose type requirements are
144 // conveyed by type attributes.
145 if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
146 valueTypes.push_back(globalVar.type());
147
148 // Requirements from values' types
149 SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
150 SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
151 for (Type valueType : valueTypes) {
152 typeExtensions.clear();
153 valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
154 if (failed(checkAndUpdateExtensionRequirements(
155 op, targetEnv, typeExtensions, deducedExtensions)))
156 return WalkResult::interrupt();
157
158 typeCapabilities.clear();
159 valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
160 if (failed(checkAndUpdateCapabilityRequirements(
161 op, targetEnv, typeCapabilities, deducedCapabilities)))
162 return WalkResult::interrupt();
163 }
164
165 return WalkResult::advance();
166 });
167
168 if (walkResult.wasInterrupted())
169 return signalPassFailure();
170
171 // TODO: verify that the deduced version is consistent with
172 // SPIR-V ops' maximal version requirements.
173
174 auto triple = spirv::VerCapExtAttr::get(
175 deducedVersion, deducedCapabilities.getArrayRef(),
176 deducedExtensions.getArrayRef(), &getContext());
177 module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
178 }
179
180 std::unique_ptr<OperationPass<spirv::ModuleOp>>
createUpdateVersionCapabilityExtensionPass()181 mlir::spirv::createUpdateVersionCapabilityExtensionPass() {
182 return std::make_unique<UpdateVCEPass>();
183 }
184