1 //===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
11 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
12 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/Pass/Pass.h"
15
16 using namespace mlir;
17
18 //===----------------------------------------------------------------------===//
19 // Printing op availability pass
20 //===----------------------------------------------------------------------===//
21
22 namespace {
23 /// A pass for testing SPIR-V op availability.
24 struct PrintOpAvailability
25 : public PassWrapper<PrintOpAvailability, OperationPass<func::FuncOp>> {
26 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability)
27
28 void runOnOperation() override;
getArgument__anonf839eacb0111::PrintOpAvailability29 StringRef getArgument() const final { return "test-spirv-op-availability"; }
getDescription__anonf839eacb0111::PrintOpAvailability30 StringRef getDescription() const final {
31 return "Test SPIR-V op availability";
32 }
33 };
34 } // namespace
35
runOnOperation()36 void PrintOpAvailability::runOnOperation() {
37 auto f = getOperation();
38 llvm::outs() << f.getName() << "\n";
39
40 Dialect *spvDialect = getContext().getLoadedDialect("spv");
41
42 f->walk([&](Operation *op) {
43 if (op->getDialect() != spvDialect)
44 return WalkResult::advance();
45
46 auto opName = op->getName();
47 auto &os = llvm::outs();
48
49 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
50 Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
51 os << opName << " min version: ";
52 if (minVersion)
53 os << spirv::stringifyVersion(*minVersion) << "\n";
54 else
55 os << "None\n";
56 }
57
58 if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
59 Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
60 os << opName << " max version: ";
61 if (maxVersion)
62 os << spirv::stringifyVersion(*maxVersion) << "\n";
63 else
64 os << "None\n";
65 }
66
67 if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
68 os << opName << " extensions: [";
69 for (const auto &exts : extension.getExtensions()) {
70 os << " [";
71 llvm::interleaveComma(exts, os, [&](spirv::Extension ext) {
72 os << spirv::stringifyExtension(ext);
73 });
74 os << "]";
75 }
76 os << " ]\n";
77 }
78
79 if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
80 os << opName << " capabilities: [";
81 for (const auto &caps : capability.getCapabilities()) {
82 os << " [";
83 llvm::interleaveComma(caps, os, [&](spirv::Capability cap) {
84 os << spirv::stringifyCapability(cap);
85 });
86 os << "]";
87 }
88 os << " ]\n";
89 }
90 os.flush();
91
92 return WalkResult::advance();
93 });
94 }
95
96 namespace mlir {
registerPrintSpirvAvailabilityPass()97 void registerPrintSpirvAvailabilityPass() {
98 PassRegistration<PrintOpAvailability>();
99 }
100 } // namespace mlir
101
102 //===----------------------------------------------------------------------===//
103 // Converting target environment pass
104 //===----------------------------------------------------------------------===//
105
106 namespace {
107 /// A pass for testing SPIR-V op availability.
108 struct ConvertToTargetEnv
109 : public PassWrapper<ConvertToTargetEnv, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonf839eacb0511::ConvertToTargetEnv110 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertToTargetEnv)
111
112 StringRef getArgument() const override { return "test-spirv-target-env"; }
getDescription__anonf839eacb0511::ConvertToTargetEnv113 StringRef getDescription() const override {
114 return "Test SPIR-V target environment";
115 }
116 void runOnOperation() override;
117 };
118
119 struct ConvertToAtomCmpExchangeWeak : public RewritePattern {
120 ConvertToAtomCmpExchangeWeak(MLIRContext *context);
121 LogicalResult matchAndRewrite(Operation *op,
122 PatternRewriter &rewriter) const override;
123 };
124
125 struct ConvertToBitReverse : public RewritePattern {
126 ConvertToBitReverse(MLIRContext *context);
127 LogicalResult matchAndRewrite(Operation *op,
128 PatternRewriter &rewriter) const override;
129 };
130
131 struct ConvertToGroupNonUniformBallot : public RewritePattern {
132 ConvertToGroupNonUniformBallot(MLIRContext *context);
133 LogicalResult matchAndRewrite(Operation *op,
134 PatternRewriter &rewriter) const override;
135 };
136
137 struct ConvertToModule : public RewritePattern {
138 ConvertToModule(MLIRContext *context);
139 LogicalResult matchAndRewrite(Operation *op,
140 PatternRewriter &rewriter) const override;
141 };
142
143 struct ConvertToSubgroupBallot : public RewritePattern {
144 ConvertToSubgroupBallot(MLIRContext *context);
145 LogicalResult matchAndRewrite(Operation *op,
146 PatternRewriter &rewriter) const override;
147 };
148 } // namespace
149
runOnOperation()150 void ConvertToTargetEnv::runOnOperation() {
151 MLIRContext *context = &getContext();
152 func::FuncOp fn = getOperation();
153
154 auto targetEnv = fn.getOperation()
155 ->getAttr(spirv::getTargetEnvAttrName())
156 .cast<spirv::TargetEnvAttr>();
157 if (!targetEnv) {
158 fn.emitError("missing 'spv.target_env' attribute");
159 return signalPassFailure();
160 }
161
162 auto target = SPIRVConversionTarget::get(targetEnv);
163
164 RewritePatternSet patterns(context);
165 patterns.add<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
166 ConvertToGroupNonUniformBallot, ConvertToModule,
167 ConvertToSubgroupBallot>(context);
168
169 if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
170 return signalPassFailure();
171 }
172
ConvertToAtomCmpExchangeWeak(MLIRContext * context)173 ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
174 : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1,
175 context, {"spv.AtomicCompareExchangeWeak"}) {}
176
177 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const178 ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
179 PatternRewriter &rewriter) const {
180 Value ptr = op->getOperand(0);
181 Value value = op->getOperand(1);
182 Value comparator = op->getOperand(2);
183
184 // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in
185 // memory semantics to additionally require AtomicStorage capability.
186 rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
187 op, value.getType(), ptr, spirv::Scope::Workgroup,
188 spirv::MemorySemantics::AcquireRelease |
189 spirv::MemorySemantics::AtomicCounterMemory,
190 spirv::MemorySemantics::Acquire, value, comparator);
191 return success();
192 }
193
ConvertToBitReverse(MLIRContext * context)194 ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context)
195 : RewritePattern("test.convert_to_bit_reverse_op", 1, context,
196 {"spv.BitReverse"}) {}
197
198 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const199 ConvertToBitReverse::matchAndRewrite(Operation *op,
200 PatternRewriter &rewriter) const {
201 Value predicate = op->getOperand(0);
202
203 rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
204 op, op->getResult(0).getType(), predicate);
205 return success();
206 }
207
ConvertToGroupNonUniformBallot(MLIRContext * context)208 ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
209 MLIRContext *context)
210 : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1, context,
211 {"spv.GroupNonUniformBallot"}) {}
212
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const213 LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
214 Operation *op, PatternRewriter &rewriter) const {
215 Value predicate = op->getOperand(0);
216
217 rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
218 op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
219 return success();
220 }
221
ConvertToModule(MLIRContext * context)222 ConvertToModule::ConvertToModule(MLIRContext *context)
223 : RewritePattern("test.convert_to_module_op", 1, context, {"spv.module"}) {}
224
225 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const226 ConvertToModule::matchAndRewrite(Operation *op,
227 PatternRewriter &rewriter) const {
228 rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
229 op, spirv::AddressingModel::PhysicalStorageBuffer64,
230 spirv::MemoryModel::Vulkan);
231 return success();
232 }
233
ConvertToSubgroupBallot(MLIRContext * context)234 ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
235 : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context,
236 {"spv.SubgroupBallotKHR"}) {}
237
238 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const239 ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
240 PatternRewriter &rewriter) const {
241 Value predicate = op->getOperand(0);
242
243 rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
244 op, op->getResult(0).getType(), predicate);
245 return success();
246 }
247
248 namespace mlir {
registerConvertToTargetEnvPass()249 void registerConvertToTargetEnvPass() {
250 PassRegistration<ConvertToTargetEnv>();
251 }
252 } // namespace mlir
253