1 //===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
10 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
11 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/Pass/Pass.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // Printing op availability pass
19 //===----------------------------------------------------------------------===//
20 
21 namespace {
22 /// A pass for testing SPIR-V op availability.
23 struct PrintOpAvailability
24     : public PassWrapper<PrintOpAvailability, OperationPass<FuncOp>> {
25   void runOnOperation() override;
26   StringRef getArgument() const final { return "test-spirv-op-availability"; }
27   StringRef getDescription() const final {
28     return "Test SPIR-V op availability";
29   }
30 };
31 } // namespace
32 
33 void PrintOpAvailability::runOnOperation() {
34   auto f = getOperation();
35   llvm::outs() << f.getName() << "\n";
36 
37   Dialect *spvDialect = getContext().getLoadedDialect("spv");
38 
39   f->walk([&](Operation *op) {
40     if (op->getDialect() != spvDialect)
41       return WalkResult::advance();
42 
43     auto opName = op->getName();
44     auto &os = llvm::outs();
45 
46     if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
47       Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
48       os << opName << " min version: ";
49       if (minVersion)
50         os << spirv::stringifyVersion(*minVersion) << "\n";
51       else
52         os << "None\n";
53     }
54 
55     if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
56       Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
57       os << opName << " max version: ";
58       if (maxVersion)
59         os << spirv::stringifyVersion(*maxVersion) << "\n";
60       else
61         os << "None\n";
62     }
63 
64     if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
65       os << opName << " extensions: [";
66       for (const auto &exts : extension.getExtensions()) {
67         os << " [";
68         llvm::interleaveComma(exts, os, [&](spirv::Extension ext) {
69           os << spirv::stringifyExtension(ext);
70         });
71         os << "]";
72       }
73       os << " ]\n";
74     }
75 
76     if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
77       os << opName << " capabilities: [";
78       for (const auto &caps : capability.getCapabilities()) {
79         os << " [";
80         llvm::interleaveComma(caps, os, [&](spirv::Capability cap) {
81           os << spirv::stringifyCapability(cap);
82         });
83         os << "]";
84       }
85       os << " ]\n";
86     }
87     os.flush();
88 
89     return WalkResult::advance();
90   });
91 }
92 
namespace mlir {
/// Registers the PrintOpAvailability test pass by constructing a
/// PassRegistration object for it. This function is not declared or called in
/// this file; presumably the test tool's registration code calls it — confirm
/// against the caller.
void registerPrintSpirvAvailabilityPass() {
  PassRegistration<PrintOpAvailability>();
}
} // namespace mlir
98 
99 //===----------------------------------------------------------------------===//
100 // Converting target environment pass
101 //===----------------------------------------------------------------------===//
102 
103 namespace {
104 /// A pass for testing SPIR-V op availability.
105 struct ConvertToTargetEnv
106     : public PassWrapper<ConvertToTargetEnv, OperationPass<FuncOp>> {
107   StringRef getArgument() const override { return "test-spirv-target-env"; }
108   StringRef getDescription() const override {
109     return "Test SPIR-V target environment";
110   }
111   void runOnOperation() override;
112 };
113 
114 struct ConvertToAtomCmpExchangeWeak : public RewritePattern {
115   ConvertToAtomCmpExchangeWeak(MLIRContext *context);
116   LogicalResult matchAndRewrite(Operation *op,
117                                 PatternRewriter &rewriter) const override;
118 };
119 
120 struct ConvertToBitReverse : public RewritePattern {
121   ConvertToBitReverse(MLIRContext *context);
122   LogicalResult matchAndRewrite(Operation *op,
123                                 PatternRewriter &rewriter) const override;
124 };
125 
126 struct ConvertToGroupNonUniformBallot : public RewritePattern {
127   ConvertToGroupNonUniformBallot(MLIRContext *context);
128   LogicalResult matchAndRewrite(Operation *op,
129                                 PatternRewriter &rewriter) const override;
130 };
131 
132 struct ConvertToModule : public RewritePattern {
133   ConvertToModule(MLIRContext *context);
134   LogicalResult matchAndRewrite(Operation *op,
135                                 PatternRewriter &rewriter) const override;
136 };
137 
138 struct ConvertToSubgroupBallot : public RewritePattern {
139   ConvertToSubgroupBallot(MLIRContext *context);
140   LogicalResult matchAndRewrite(Operation *op,
141                                 PatternRewriter &rewriter) const override;
142 };
143 } // namespace
144 
145 void ConvertToTargetEnv::runOnOperation() {
146   MLIRContext *context = &getContext();
147   FuncOp fn = getOperation();
148 
149   auto targetEnv = fn.getOperation()
150                        ->getAttr(spirv::getTargetEnvAttrName())
151                        .cast<spirv::TargetEnvAttr>();
152   if (!targetEnv) {
153     fn.emitError("missing 'spv.target_env' attribute");
154     return signalPassFailure();
155   }
156 
157   auto target = SPIRVConversionTarget::get(targetEnv);
158 
159   RewritePatternSet patterns(context);
160   patterns.add<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
161                ConvertToGroupNonUniformBallot, ConvertToModule,
162                ConvertToSubgroupBallot>(context);
163 
164   if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
165     return signalPassFailure();
166 }
167 
/// Initializes the RewritePattern base with the root op name
/// "test.convert_to_atomic_compare_exchange_weak_op", the value 1 (presumably
/// the pattern benefit — confirm against RewritePattern's constructor),
/// `context`, and a list naming "spv.AtomicCompareExchangeWeak" (presumably
/// the ops this pattern may generate).
ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
    : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1,
                     context, {"spv.AtomicCompareExchangeWeak"}) {}
171 
172 LogicalResult
173 ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
174                                               PatternRewriter &rewriter) const {
175   Value ptr = op->getOperand(0);
176   Value value = op->getOperand(1);
177   Value comparator = op->getOperand(2);
178 
179   // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in
180   // memory semantics to additionally require AtomicStorage capability.
181   rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
182       op, value.getType(), ptr, spirv::Scope::Workgroup,
183       spirv::MemorySemantics::AcquireRelease |
184           spirv::MemorySemantics::AtomicCounterMemory,
185       spirv::MemorySemantics::Acquire, value, comparator);
186   return success();
187 }
188 
/// Initializes the RewritePattern base with the root op name
/// "test.convert_to_bit_reverse_op", the value 1 (presumably the pattern
/// benefit — confirm against RewritePattern's constructor), `context`, and a
/// list naming "spv.BitReverse" (presumably the ops this pattern may
/// generate).
ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context)
    : RewritePattern("test.convert_to_bit_reverse_op", 1, context,
                     {"spv.BitReverse"}) {}
192 
193 LogicalResult
194 ConvertToBitReverse::matchAndRewrite(Operation *op,
195                                      PatternRewriter &rewriter) const {
196   Value predicate = op->getOperand(0);
197 
198   rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
199       op, op->getResult(0).getType(), predicate);
200   return success();
201 }
202 
/// Initializes the RewritePattern base with the root op name
/// "test.convert_to_group_non_uniform_ballot_op", the value 1 (presumably the
/// pattern benefit — confirm against RewritePattern's constructor), `context`,
/// and a list naming "spv.GroupNonUniformBallot" (presumably the ops this
/// pattern may generate).
ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
    MLIRContext *context)
    : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1, context,
                     {"spv.GroupNonUniformBallot"}) {}
207 
208 LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
209     Operation *op, PatternRewriter &rewriter) const {
210   Value predicate = op->getOperand(0);
211 
212   rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
213       op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
214   return success();
215 }
216 
/// Initializes the RewritePattern base with the root op name
/// "test.convert_to_module_op", the value 1 (presumably the pattern benefit —
/// confirm against RewritePattern's constructor), `context`, and a list naming
/// "spv.module" (presumably the ops this pattern may generate).
ConvertToModule::ConvertToModule(MLIRContext *context)
    : RewritePattern("test.convert_to_module_op", 1, context, {"spv.module"}) {}
219 
220 LogicalResult
221 ConvertToModule::matchAndRewrite(Operation *op,
222                                  PatternRewriter &rewriter) const {
223   rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
224       op, spirv::AddressingModel::PhysicalStorageBuffer64,
225       spirv::MemoryModel::Vulkan);
226   return success();
227 }
228 
/// Initializes the RewritePattern base with the root op name
/// "test.convert_to_subgroup_ballot_op", the value 1 (presumably the pattern
/// benefit — confirm against RewritePattern's constructor), `context`, and a
/// list naming "spv.SubgroupBallotKHR" (presumably the ops this pattern may
/// generate).
ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
    : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context,
                     {"spv.SubgroupBallotKHR"}) {}
232 
233 LogicalResult
234 ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
235                                          PatternRewriter &rewriter) const {
236   Value predicate = op->getOperand(0);
237 
238   rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
239       op, op->getResult(0).getType(), predicate);
240   return success();
241 }
242 
namespace mlir {
/// Registers the ConvertToTargetEnv test pass by constructing a
/// PassRegistration object for it. This function is not declared or called in
/// this file; presumably the test tool's registration code calls it — confirm
/// against the caller.
void registerConvertToTargetEnvPass() {
  PassRegistration<ConvertToTargetEnv>();
}
} // namespace mlir
248