1 //===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to deduce minimal version/extension/capability
10 // requirements for a spirv::ModuleOp.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Dialect/SPIRV/Passes.h"
15 #include "mlir/Dialect/SPIRV/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
18 #include "mlir/Dialect/SPIRV/TargetAndABI.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/Visitors.h"
21 #include "llvm/ADT/SetVector.h"
22 #include "llvm/ADT/SmallSet.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 /// Pass to deduce minimal version/extension/capability requirements for a
28 /// spirv::ModuleOp.
29 class UpdateVCEPass final
30     : public OperationPass<UpdateVCEPass, spirv::ModuleOp> {
31 private:
32   void runOnOperation() override;
33 };
34 } // namespace
35 
36 /// Checks that `candidates` extension requirements are possible to be satisfied
37 /// with the given `allowedExtensions` and updates `deducedExtensions` if so.
38 /// Emits errors attaching to the given `op` on failures.
39 ///
40 ///  `candidates` is a vector of vector for extension requirements following
41 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
42 /// convention.
43 static LogicalResult checkAndUpdateExtensionRequirements(
44     Operation *op, const llvm::SmallSet<spirv::Extension, 4> &allowedExtensions,
45     const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
46     llvm::SetVector<spirv::Extension> &deducedExtensions) {
47   for (const auto &ors : candidates) {
48     auto chosen = llvm::find_if(ors, [&](spirv::Extension ext) {
49       return allowedExtensions.count(ext);
50     });
51 
52     if (chosen != ors.end()) {
53       deducedExtensions.insert(*chosen);
54     } else {
55       SmallVector<StringRef, 4> extStrings;
56       for (spirv::Extension ext : ors)
57         extStrings.push_back(spirv::stringifyExtension(ext));
58 
59       return op->emitError("'")
60              << op->getName() << "' requires at least one extension in ["
61              << llvm::join(extStrings, ", ")
62              << "] but none allowed in target environment";
63     }
64   }
65   return success();
66 }
67 
68 /// Checks that `candidates`capability requirements are possible to be satisfied
69 /// with the given `allowedCapabilities` and updates `deducedCapabilities` if
70 /// so. Emits errors attaching to the given `op` on failures.
71 ///
72 ///  `candidates` is a vector of vector for capability requirements following
73 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
74 /// convention.
75 static LogicalResult checkAndUpdateCapabilityRequirements(
76     Operation *op,
77     const llvm::SmallSet<spirv::Capability, 8> &allowedCapabilities,
78     const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
79     llvm::SetVector<spirv::Capability> &deducedCapabilities) {
80   for (const auto &ors : candidates) {
81     auto chosen = llvm::find_if(ors, [&](spirv::Capability cap) {
82       return allowedCapabilities.count(cap);
83     });
84 
85     if (chosen != ors.end()) {
86       deducedCapabilities.insert(*chosen);
87     } else {
88       SmallVector<StringRef, 4> capStrings;
89       for (spirv::Capability cap : ors)
90         capStrings.push_back(spirv::stringifyCapability(cap));
91 
92       return op->emitError("'")
93              << op->getName() << "' requires at least one capability in ["
94              << llvm::join(capStrings, ", ")
95              << "] but none allowed in target environment";
96     }
97   }
98   return success();
99 }
100 
101 void UpdateVCEPass::runOnOperation() {
102   spirv::ModuleOp module = getOperation();
103 
104   spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnv(module);
105   if (!targetEnv) {
106     module.emitError("missing 'spv.target_env' attribute");
107     return signalPassFailure();
108   }
109 
110   spirv::Version allowedVersion = targetEnv.getVersion();
111 
112   // Build a set for available extensions in the target environment.
113   llvm::SmallSet<spirv::Extension, 4> allowedExtensions;
114   for (spirv::Extension ext : targetEnv.getExtensions())
115     allowedExtensions.insert(ext);
116 
117   // Add extensions implied by the current version.
118   for (spirv::Extension ext : spirv::getImpliedExtensions(allowedVersion))
119     allowedExtensions.insert(ext);
120 
121   // Build a set for available capabilities in the target environment.
122   llvm::SmallSet<spirv::Capability, 8> allowedCapabilities;
123   for (spirv::Capability cap : targetEnv.getCapabilities()) {
124     allowedCapabilities.insert(cap);
125 
126     // Add capabilities implied by the current capability.
127     for (spirv::Capability c : spirv::getRecursiveImpliedCapabilities(cap))
128       allowedCapabilities.insert(c);
129   }
130 
131   spirv::Version deducedVersion = spirv::Version::V_1_0;
132   llvm::SetVector<spirv::Extension> deducedExtensions;
133   llvm::SetVector<spirv::Capability> deducedCapabilities;
134 
135   // Walk each SPIR-V op to deduce the minimal version/extension/capability
136   // requirements.
137   WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
138     // Op min version requirements
139     if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
140       deducedVersion = std::max(deducedVersion, minVersion.getMinVersion());
141       if (deducedVersion > allowedVersion) {
142         return op->emitError("'") << op->getName() << "' requires min version "
143                                   << spirv::stringifyVersion(deducedVersion)
144                                   << " but target environment allows up to "
145                                   << spirv::stringifyVersion(allowedVersion);
146       }
147     }
148 
149     // Op extension requirements
150     if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
151       if (failed(checkAndUpdateExtensionRequirements(op, allowedExtensions,
152                                                      extensions.getExtensions(),
153                                                      deducedExtensions)))
154         return WalkResult::interrupt();
155 
156     // Op capability requirements
157     if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
158       if (failed(checkAndUpdateCapabilityRequirements(
159               op, allowedCapabilities, capabilities.getCapabilities(),
160               deducedCapabilities)))
161         return WalkResult::interrupt();
162 
163     SmallVector<Type, 4> valueTypes;
164     valueTypes.append(op->operand_type_begin(), op->operand_type_end());
165     valueTypes.append(op->result_type_begin(), op->result_type_end());
166 
167     // Special treatment for global variables, whose type requirements are
168     // conveyed by type attributes.
169     if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
170       valueTypes.push_back(globalVar.type());
171 
172     // Requirements from values' types
173     SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
174     SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
175     for (Type valueType : valueTypes) {
176       typeExtensions.clear();
177       valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
178       if (failed(checkAndUpdateExtensionRequirements(
179               op, allowedExtensions, typeExtensions, deducedExtensions)))
180         return WalkResult::interrupt();
181 
182       typeCapabilities.clear();
183       valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
184       if (failed(checkAndUpdateCapabilityRequirements(
185               op, allowedCapabilities, typeCapabilities, deducedCapabilities)))
186         return WalkResult::interrupt();
187     }
188 
189     return WalkResult::advance();
190   });
191 
192   if (walkResult.wasInterrupted())
193     return signalPassFailure();
194 
195   // TODO(antiagainst): verify that the deduced version is consistent with
196   // SPIR-V ops' maximal version requirements.
197 
198   auto triple = spirv::VerCapExtAttr::get(
199       deducedVersion, deducedCapabilities.getArrayRef(),
200       deducedExtensions.getArrayRef(), &getContext());
201   module.setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
202 }
203 
204 std::unique_ptr<OpPassBase<spirv::ModuleOp>>
205 mlir::spirv::createUpdateVersionCapabilityExtensionPass() {
206   return std::make_unique<UpdateVCEPass>();
207 }
208 
209 static PassRegistration<UpdateVCEPass>
210     pass("spirv-update-vce",
211          "Deduce and attach minimal (version, capabilities, extensions) "
212          "requirements to spv.module ops");
213