19414db10SLei Zhang //===- DeduceVersionExtensionCapabilityPass.cpp ---------------------------===//
29414db10SLei Zhang //
39414db10SLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49414db10SLei Zhang // See https://llvm.org/LICENSE.txt for license information.
59414db10SLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69414db10SLei Zhang //
79414db10SLei Zhang //===----------------------------------------------------------------------===//
89414db10SLei Zhang //
99414db10SLei Zhang // This file implements a pass to deduce minimal version/extension/capability
109414db10SLei Zhang // requirements for a spirv::ModuleOp.
119414db10SLei Zhang //
129414db10SLei Zhang //===----------------------------------------------------------------------===//
139414db10SLei Zhang
141834ad4aSRiver Riddle #include "PassDetail.h"
1501178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1601178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1701178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1801178654SLei Zhang #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
1901178654SLei Zhang #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
209414db10SLei Zhang #include "mlir/IR/Builders.h"
219414db10SLei Zhang #include "mlir/IR/Visitors.h"
229414db10SLei Zhang #include "llvm/ADT/SetVector.h"
239414db10SLei Zhang #include "llvm/ADT/SmallSet.h"
24297a5b7cSNico Weber #include "llvm/ADT/StringExtras.h"
259414db10SLei Zhang
269414db10SLei Zhang using namespace mlir;
279414db10SLei Zhang
289414db10SLei Zhang namespace {
299414db10SLei Zhang /// Pass to deduce minimal version/extension/capability requirements for a
309414db10SLei Zhang /// spirv::ModuleOp.
311834ad4aSRiver Riddle class UpdateVCEPass final : public SPIRVUpdateVCEBase<UpdateVCEPass> {
329414db10SLei Zhang void runOnOperation() override;
339414db10SLei Zhang };
349414db10SLei Zhang } // namespace
359414db10SLei Zhang
36e5c85a5aSLei Zhang /// Checks that `candidates` extension requirements are possible to be satisfied
3758df5e6dSLei Zhang /// with the given `targetEnv` and updates `deducedExtensions` if so. Emits
3858df5e6dSLei Zhang /// errors attaching to the given `op` on failures.
39e5c85a5aSLei Zhang ///
40e5c85a5aSLei Zhang /// `candidates` is a vector of vector for extension requirements following
41e5c85a5aSLei Zhang /// ((Extension::A OR Extension::B) AND (Extension::C OR Extension::D))
42e5c85a5aSLei Zhang /// convention.
checkAndUpdateExtensionRequirements(Operation * op,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::ExtensionArrayRefVector & candidates,SetVector<spirv::Extension> & deducedExtensions)43e5c85a5aSLei Zhang static LogicalResult checkAndUpdateExtensionRequirements(
4458df5e6dSLei Zhang Operation *op, const spirv::TargetEnv &targetEnv,
45e5c85a5aSLei Zhang const spirv::SPIRVType::ExtensionArrayRefVector &candidates,
464efb7754SRiver Riddle SetVector<spirv::Extension> &deducedExtensions) {
47e5c85a5aSLei Zhang for (const auto &ors : candidates) {
4858df5e6dSLei Zhang if (Optional<spirv::Extension> chosen = targetEnv.allows(ors)) {
49e5c85a5aSLei Zhang deducedExtensions.insert(*chosen);
50e5c85a5aSLei Zhang } else {
51e5c85a5aSLei Zhang SmallVector<StringRef, 4> extStrings;
52e5c85a5aSLei Zhang for (spirv::Extension ext : ors)
53e5c85a5aSLei Zhang extStrings.push_back(spirv::stringifyExtension(ext));
54e5c85a5aSLei Zhang
55e5c85a5aSLei Zhang return op->emitError("'")
56e5c85a5aSLei Zhang << op->getName() << "' requires at least one extension in ["
57e5c85a5aSLei Zhang << llvm::join(extStrings, ", ")
58e5c85a5aSLei Zhang << "] but none allowed in target environment";
59e5c85a5aSLei Zhang }
60e5c85a5aSLei Zhang }
61e5c85a5aSLei Zhang return success();
62e5c85a5aSLei Zhang }
63e5c85a5aSLei Zhang
64e5c85a5aSLei Zhang /// Checks that `candidates`capability requirements are possible to be satisfied
6558df5e6dSLei Zhang /// with the given `targetEnv` and updates `deducedCapabilities` if so. Emits
6658df5e6dSLei Zhang /// errors attaching to the given `op` on failures.
67e5c85a5aSLei Zhang ///
68e5c85a5aSLei Zhang /// `candidates` is a vector of vector for capability requirements following
69e5c85a5aSLei Zhang /// ((Capability::A OR Capability::B) AND (Capability::C OR Capability::D))
70e5c85a5aSLei Zhang /// convention.
checkAndUpdateCapabilityRequirements(Operation * op,const spirv::TargetEnv & targetEnv,const spirv::SPIRVType::CapabilityArrayRefVector & candidates,SetVector<spirv::Capability> & deducedCapabilities)71e5c85a5aSLei Zhang static LogicalResult checkAndUpdateCapabilityRequirements(
7258df5e6dSLei Zhang Operation *op, const spirv::TargetEnv &targetEnv,
73e5c85a5aSLei Zhang const spirv::SPIRVType::CapabilityArrayRefVector &candidates,
744efb7754SRiver Riddle SetVector<spirv::Capability> &deducedCapabilities) {
75e5c85a5aSLei Zhang for (const auto &ors : candidates) {
7658df5e6dSLei Zhang if (Optional<spirv::Capability> chosen = targetEnv.allows(ors)) {
77e5c85a5aSLei Zhang deducedCapabilities.insert(*chosen);
78e5c85a5aSLei Zhang } else {
79e5c85a5aSLei Zhang SmallVector<StringRef, 4> capStrings;
80e5c85a5aSLei Zhang for (spirv::Capability cap : ors)
81e5c85a5aSLei Zhang capStrings.push_back(spirv::stringifyCapability(cap));
82e5c85a5aSLei Zhang
83e5c85a5aSLei Zhang return op->emitError("'")
84e5c85a5aSLei Zhang << op->getName() << "' requires at least one capability in ["
85e5c85a5aSLei Zhang << llvm::join(capStrings, ", ")
86e5c85a5aSLei Zhang << "] but none allowed in target environment";
87e5c85a5aSLei Zhang }
88e5c85a5aSLei Zhang }
89e5c85a5aSLei Zhang return success();
90e5c85a5aSLei Zhang }
91e5c85a5aSLei Zhang
runOnOperation()929414db10SLei Zhang void UpdateVCEPass::runOnOperation() {
939414db10SLei Zhang spirv::ModuleOp module = getOperation();
949414db10SLei Zhang
9558df5e6dSLei Zhang spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnv(module);
9658df5e6dSLei Zhang if (!targetAttr) {
979414db10SLei Zhang module.emitError("missing 'spv.target_env' attribute");
989414db10SLei Zhang return signalPassFailure();
999414db10SLei Zhang }
1009414db10SLei Zhang
10158df5e6dSLei Zhang spirv::TargetEnv targetEnv(targetAttr);
10258df5e6dSLei Zhang spirv::Version allowedVersion = targetAttr.getVersion();
1039414db10SLei Zhang
1049414db10SLei Zhang spirv::Version deducedVersion = spirv::Version::V_1_0;
1054efb7754SRiver Riddle SetVector<spirv::Extension> deducedExtensions;
1064efb7754SRiver Riddle SetVector<spirv::Capability> deducedCapabilities;
1079414db10SLei Zhang
1089414db10SLei Zhang // Walk each SPIR-V op to deduce the minimal version/extension/capability
1099414db10SLei Zhang // requirements.
1109414db10SLei Zhang WalkResult walkResult = module.walk([&](Operation *op) -> WalkResult {
111e5c85a5aSLei Zhang // Op min version requirements
112*cb395f66SLei Zhang if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
113*cb395f66SLei Zhang Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
114*cb395f66SLei Zhang if (minVersion) {
115*cb395f66SLei Zhang deducedVersion = std::max(deducedVersion, *minVersion);
1169414db10SLei Zhang if (deducedVersion > allowedVersion) {
117*cb395f66SLei Zhang return op->emitError("'")
118*cb395f66SLei Zhang << op->getName() << "' requires min version "
1199414db10SLei Zhang << spirv::stringifyVersion(deducedVersion)
1209414db10SLei Zhang << " but target environment allows up to "
1219414db10SLei Zhang << spirv::stringifyVersion(allowedVersion);
1229414db10SLei Zhang }
1239414db10SLei Zhang }
124*cb395f66SLei Zhang }
1259414db10SLei Zhang
126e5c85a5aSLei Zhang // Op extension requirements
127e5c85a5aSLei Zhang if (auto extensions = dyn_cast<spirv::QueryExtensionInterface>(op))
12858df5e6dSLei Zhang if (failed(checkAndUpdateExtensionRequirements(
12958df5e6dSLei Zhang op, targetEnv, extensions.getExtensions(), deducedExtensions)))
130e5c85a5aSLei Zhang return WalkResult::interrupt();
1319414db10SLei Zhang
132e5c85a5aSLei Zhang // Op capability requirements
133e5c85a5aSLei Zhang if (auto capabilities = dyn_cast<spirv::QueryCapabilityInterface>(op))
134e5c85a5aSLei Zhang if (failed(checkAndUpdateCapabilityRequirements(
13558df5e6dSLei Zhang op, targetEnv, capabilities.getCapabilities(),
136e5c85a5aSLei Zhang deducedCapabilities)))
137e5c85a5aSLei Zhang return WalkResult::interrupt();
1389414db10SLei Zhang
139e5c85a5aSLei Zhang SmallVector<Type, 4> valueTypes;
140e5c85a5aSLei Zhang valueTypes.append(op->operand_type_begin(), op->operand_type_end());
141e5c85a5aSLei Zhang valueTypes.append(op->result_type_begin(), op->result_type_end());
1429414db10SLei Zhang
143e5c85a5aSLei Zhang // Special treatment for global variables, whose type requirements are
144e5c85a5aSLei Zhang // conveyed by type attributes.
145e5c85a5aSLei Zhang if (auto globalVar = dyn_cast<spirv::GlobalVariableOp>(op))
146e5c85a5aSLei Zhang valueTypes.push_back(globalVar.type());
1479414db10SLei Zhang
148e5c85a5aSLei Zhang // Requirements from values' types
149e5c85a5aSLei Zhang SmallVector<ArrayRef<spirv::Extension>, 4> typeExtensions;
150e5c85a5aSLei Zhang SmallVector<ArrayRef<spirv::Capability>, 8> typeCapabilities;
151e5c85a5aSLei Zhang for (Type valueType : valueTypes) {
152e5c85a5aSLei Zhang typeExtensions.clear();
153e5c85a5aSLei Zhang valueType.cast<spirv::SPIRVType>().getExtensions(typeExtensions);
154e5c85a5aSLei Zhang if (failed(checkAndUpdateExtensionRequirements(
15558df5e6dSLei Zhang op, targetEnv, typeExtensions, deducedExtensions)))
156e5c85a5aSLei Zhang return WalkResult::interrupt();
1579414db10SLei Zhang
158e5c85a5aSLei Zhang typeCapabilities.clear();
159e5c85a5aSLei Zhang valueType.cast<spirv::SPIRVType>().getCapabilities(typeCapabilities);
160e5c85a5aSLei Zhang if (failed(checkAndUpdateCapabilityRequirements(
16158df5e6dSLei Zhang op, targetEnv, typeCapabilities, deducedCapabilities)))
162e5c85a5aSLei Zhang return WalkResult::interrupt();
1639414db10SLei Zhang }
1649414db10SLei Zhang
1659414db10SLei Zhang return WalkResult::advance();
1669414db10SLei Zhang });
1679414db10SLei Zhang
1689414db10SLei Zhang if (walkResult.wasInterrupted())
1699414db10SLei Zhang return signalPassFailure();
1709414db10SLei Zhang
1719db53a18SRiver Riddle // TODO: verify that the deduced version is consistent with
1729414db10SLei Zhang // SPIR-V ops' maximal version requirements.
1739414db10SLei Zhang
1749414db10SLei Zhang auto triple = spirv::VerCapExtAttr::get(
1759414db10SLei Zhang deducedVersion, deducedCapabilities.getArrayRef(),
1769414db10SLei Zhang deducedExtensions.getArrayRef(), &getContext());
1771ffc1aaaSChristian Sigg module->setAttr(spirv::ModuleOp::getVCETripleAttrName(), triple);
1789414db10SLei Zhang }
1799414db10SLei Zhang
18080aca1eaSRiver Riddle std::unique_ptr<OperationPass<spirv::ModuleOp>>
createUpdateVersionCapabilityExtensionPass()1819414db10SLei Zhang mlir::spirv::createUpdateVersionCapabilityExtensionPass() {
1829414db10SLei Zhang return std::make_unique<UpdateVCEPass>();
1839414db10SLei Zhang }
184