1 //===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to deduce minimal version/extension/capability
10 // requirements for a spirv::ModuleOp.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
19 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Visitors.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/StringExtras.h"
25 
26 using namespace mlir;
27 
28 namespace {
29 /// Pass to deduce minimal version/extension/capability requirements for a
30 /// spirv::ModuleOp.
31 class UpdateVCEPass final : public SPIRVUpdateVCEBase<UpdateVCEPass> {
32   void runOnOperation() override;
33 };
34 } // namespace
35 
36 /// Checks that `candidates` extension requirements are possible to be satisfied
37 /// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
38 /// errors attaching to the given `op` on failures.
39 ///
40 ///  `candidates` is a vector of vector for extension requirements following
41 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
42 /// convention.
43 static LogicalResult checkAndUpdateExtensionRequirements(
44     Operation *op, const spirv::TargetEnv &targetEnv,
45     const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
46     SetVector<spirv::Extension> &deducedExtensions) {
47   for (const auto &ors : candidates) {
48     if (Optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
49       deducedExtensions.insert(*chosen);
50     } else {
51       SmallVector<StringRef, 4> extStrings;
52       for (spirv::Extension ext : ors)
53         extStrings.push_back(spirv::stringifyExtension(ext));
54 
55       return op->emitError("'")
56              << op->getName() << "' requires at least one extension in ["
57              << llvm::join(extStrings, ", ")
58              << "] but none allowed in target environment";
59     }
60   }
61   return success();
62 }
63 
64 /// Checks that `candidates`capability requirements are possible to be satisfied
65 /// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
66 /// errors attaching to the given `op` on failures.
67 ///
68 ///  `candidates` is a vector of vector for capability requirements following
69 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
70 /// convention.
71 static LogicalResult checkAndUpdateCapabilityRequirements(
72     Operation *op, const spirv::TargetEnv &targetEnv,
73     const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
74     SetVector<spirv::Capability> &deducedCapabilities) {
75   for (const auto &ors : candidates) {
76     if (Optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
77       deducedCapabilities.insert(*chosen);
78     } else {
79       SmallVector<StringRef, 4> capStrings;
80       for (spirv::Capability cap : ors)
81         capStrings.push_back(spirv::stringifyCapability(cap));
82 
83       return op->emitError("'")
84              << op->getName() << "' requires at least one capability in ["
85              << llvm::join(capStrings, ", ")
86              << "] but none allowed in target environment";
87     }
88   }
89   return success();
90 }
91 
92 void UpdateVCEPass::runOnOperation() {
93   spirv::ModuleOp module = getOperation();
94 
95   spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
96   if (!targetAttr) {
97     module.emitError("missing 'spv.target_env' attribute");
98     return signalPassFailure();
99   }
100 
101   spirv::TargetEnv targetEnv(targetAttr);
102   spirv::Version allowedVersion = targetAttr.getVersion();
103 
104   spirv::Version deducedVersion = spirv::Version::V_1_0;
105   SetVector<spirv::Extension> deducedExtensions;
106   SetVector<spirv::Capability> deducedCapabilities;
107 
108   // Walk each SPIR-V op to deduce the minimal version/extension/capability
109   // requirements.
110   WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
111     // Op min version requirements
112     if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
113       deducedVersion = std::max(deducedVersion, minVersion.getMinVersion());
114       if (deducedVersion > allowedVersion) {
115         return op->emitError("'") << op->getName() << "' requires min version "
116                                   << spirv::stringifyVersion(deducedVersion)
117                                   << " but target environment allows up to "
118                                   << spirv::stringifyVersion(allowedVersion);
119       }
120     }
121 
122     // Op extension requirements
123     if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
124       if (failed(checkAndUpdateExtensionRequirements(
125               op, targetEnv, extensions.getExtensions(), deducedExtensions)))
126         return WalkResult::interrupt();
127 
128     // Op capability requirements
129     if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
130       if (failed(checkAndUpdateCapabilityRequirements(
131               op, targetEnv, capabilities.getCapabilities(),
132               deducedCapabilities)))
133         return WalkResult::interrupt();
134 
135     SmallVector<Type, 4> valueTypes;
136     valueTypes.append(op->operand_type_begin(), op->operand_type_end());
137     valueTypes.append(op->result_type_begin(), op->result_type_end());
138 
139     // Special treatment for global variables, whose type requirements are
140     // conveyed by type attributes.
141     if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
142       valueTypes.push_back(globalVar.type());
143 
144     // Requirements from values' types
145     SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
146     SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
147     for (Type valueType : valueTypes) {
148       typeExtensions.clear();
149       valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
150       if (failed(checkAndUpdateExtensionRequirements(
151               op, targetEnv, typeExtensions, deducedExtensions)))
152         return WalkResult::interrupt();
153 
154       typeCapabilities.clear();
155       valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
156       if (failed(checkAndUpdateCapabilityRequirements(
157               op, targetEnv, typeCapabilities, deducedCapabilities)))
158         return WalkResult::interrupt();
159     }
160 
161     return WalkResult::advance();
162   });
163 
164   if (walkResult.wasInterrupted())
165     return signalPassFailure();
166 
167   // TODO: verify that the deduced version is consistent with
168   // SPIR-V ops' maximal version requirements.
169 
170   auto triple = spirv::VerCapExtAttr::get(
171       deducedVersion, deducedCapabilities.getArrayRef(),
172       deducedExtensions.getArrayRef(), &getContext());
173   module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
174 }
175 
176 std::unique_ptr<OperationPass<spirv::ModuleOp>>
177 mlir::spirv::createUpdateVersionCapabilityExtensionPass() {
178   return std::make_unique<UpdateVCEPass>();
179 }
180