1 //===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to deduce minimal version/extension/capability
10 // requirements for a spirv::ModuleOp.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
18 #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
19 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Visitors.h"
22 #include "llvm/ADT/SetVector.h"
23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/StringExtras.h"
25 
26 using namespace mlir;
27 
28 namespace {
29 /// Pass to deduce minimal version/extension/capability requirements for a
30 /// spirv::ModuleOp.
31 class UpdateVCEPass final : public SPIRVUpdateVCEBase<UpdateVCEPass> {
32   void runOnOperation() override;
33 };
34 } // namespace
35 
36 /// Checks that `candidates` extension requirements are possible to be satisfied
37 /// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
38 /// errors attaching to the given `op` on failures.
39 ///
40 ///  `candidates` is a vector of vector for extension requirements following
41 /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
42 /// convention.
checkAndUpdateExtensionRequirements(Operation * op,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::ExtensionArrayRefVector & candidates,SetVector<spirv::Extension> & deducedExtensions)43 static LogicalResult checkAndUpdateExtensionRequirements(
44     Operation *op, const spirv::TargetEnv &targetEnv,
45     const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
46     SetVector<spirv::Extension> &deducedExtensions) {
47   for (const auto &ors : candidates) {
48     if (Optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
49       deducedExtensions.insert(*chosen);
50     } else {
51       SmallVector<StringRef, 4> extStrings;
52       for (spirv::Extension ext : ors)
53         extStrings.push_back(spirv::stringifyExtension(ext));
54 
55       return op->emitError("'")
56              << op->getName() << "' requires at least one extension in ["
57              << llvm::join(extStrings, ", ")
58              << "] but none allowed in target environment";
59     }
60   }
61   return success();
62 }
63 
64 /// Checks that `candidates`capability requirements are possible to be satisfied
65 /// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
66 /// errors attaching to the given `op` on failures.
67 ///
68 ///  `candidates` is a vector of vector for capability requirements following
69 /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
70 /// convention.
checkAndUpdateCapabilityRequirements(Operation * op,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::CapabilityArrayRefVector & candidates,SetVector<spirv::Capability> & deducedCapabilities)71 static LogicalResult checkAndUpdateCapabilityRequirements(
72     Operation *op, const spirv::TargetEnv &targetEnv,
73     const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
74     SetVector<spirv::Capability> &deducedCapabilities) {
75   for (const auto &ors : candidates) {
76     if (Optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
77       deducedCapabilities.insert(*chosen);
78     } else {
79       SmallVector<StringRef, 4> capStrings;
80       for (spirv::Capability cap : ors)
81         capStrings.push_back(spirv::stringifyCapability(cap));
82 
83       return op->emitError("'")
84              << op->getName() << "' requires at least one capability in ["
85              << llvm::join(capStrings, ", ")
86              << "] but none allowed in target environment";
87     }
88   }
89   return success();
90 }
91 
runOnOperation()92 void UpdateVCEPass::runOnOperation() {
93   spirv::ModuleOp module = getOperation();
94 
95   spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
96   if (!targetAttr) {
97     module.emitError("missing 'spv.target_env' attribute");
98     return signalPassFailure();
99   }
100 
101   spirv::TargetEnv targetEnv(targetAttr);
102   spirv::Version allowedVersion = targetAttr.getVersion();
103 
104   spirv::Version deducedVersion = spirv::Version::V_1_0;
105   SetVector<spirv::Extension> deducedExtensions;
106   SetVector<spirv::Capability> deducedCapabilities;
107 
108   // Walk each SPIR-V op to deduce the minimal version/extension/capability
109   // requirements.
110   WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
111     // Op min version requirements
112     if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
113       Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
114       if (minVersion) {
115         deducedVersion = std::max(deducedVersion, *minVersion);
116         if (deducedVersion > allowedVersion) {
117           return op->emitError("'")
118                  << op->getName() << "' requires min version "
119                  << spirv::stringifyVersion(deducedVersion)
120                  << " but target environment allows up to "
121                  << spirv::stringifyVersion(allowedVersion);
122         }
123       }
124     }
125 
126     // Op extension requirements
127     if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
128       if (failed(checkAndUpdateExtensionRequirements(
129               op, targetEnv, extensions.getExtensions(), deducedExtensions)))
130         return WalkResult::interrupt();
131 
132     // Op capability requirements
133     if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
134       if (failed(checkAndUpdateCapabilityRequirements(
135               op, targetEnv, capabilities.getCapabilities(),
136               deducedCapabilities)))
137         return WalkResult::interrupt();
138 
139     SmallVector<Type, 4> valueTypes;
140     valueTypes.append(op->operand_type_begin(), op->operand_type_end());
141     valueTypes.append(op->result_type_begin(), op->result_type_end());
142 
143     // Special treatment for global variables, whose type requirements are
144     // conveyed by type attributes.
145     if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
146       valueTypes.push_back(globalVar.type());
147 
148     // Requirements from values' types
149     SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
150     SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
151     for (Type valueType : valueTypes) {
152       typeExtensions.clear();
153       valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
154       if (failed(checkAndUpdateExtensionRequirements(
155               op, targetEnv, typeExtensions, deducedExtensions)))
156         return WalkResult::interrupt();
157 
158       typeCapabilities.clear();
159       valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
160       if (failed(checkAndUpdateCapabilityRequirements(
161               op, targetEnv, typeCapabilities, deducedCapabilities)))
162         return WalkResult::interrupt();
163     }
164 
165     return WalkResult::advance();
166   });
167 
168   if (walkResult.wasInterrupted())
169     return signalPassFailure();
170 
171   // TODO: verify that the deduced version is consistent with
172   // SPIR-V ops' maximal version requirements.
173 
174   auto triple = spirv::VerCapExtAttr::get(
175       deducedVersion, deducedCapabilities.getArrayRef(),
176       deducedExtensions.getArrayRef(), &getContext());
177   module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
178 }
179 
180 std::unique_ptr<OperationPass<spirv::ModuleOp>>
createUpdateVersionCapabilityExtensionPass()181 mlir::spirv::createUpdateVersionCapabilityExtensionPass() {
182   return std::make_unique<UpdateVCEPass>();
183 }
184