1 //===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
10 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
11 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/Pass/Pass.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // Printing op availability pass
19 //===----------------------------------------------------------------------===//
20 
21 namespace {
22 /// A pass for testing SPIR-V op availability.
23 struct PrintOpAvailability
24     : public PassWrapper<PrintOpAvailability, FunctionPass> {
25   void runOnFunction() override;
26   StringRef getArgument() const final { return "test-spirv-op-availability"; }
27   StringRef getDescription() const final {
28     return "Test SPIR-V op availability";
29   }
30 };
31 } // end anonymous namespace
32 
33 void PrintOpAvailability::runOnFunction() {
34   auto f = getFunction();
35   llvm::outs() << f.getName() << "\n";
36 
37   Dialect *spvDialect = getContext().getLoadedDialect("spv");
38 
39   f->walk([&](Operation *op) {
40     if (op->getDialect() != spvDialect)
41       return WalkResult::advance();
42 
43     auto opName = op->getName();
44     auto &os = llvm::outs();
45 
46     if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
47       os << opName << " min version: "
48          << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n";
49 
50     if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
51       os << opName << " max version: "
52          << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n";
53 
54     if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
55       os << opName << " extensions: [";
56       for (const auto &exts : extension.getExtensions()) {
57         os << " [";
58         llvm::interleaveComma(exts, os, [&](spirv::Extension ext) {
59           os << spirv::stringifyExtension(ext);
60         });
61         os << "]";
62       }
63       os << " ]\n";
64     }
65 
66     if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
67       os << opName << " capabilities: [";
68       for (const auto &caps : capability.getCapabilities()) {
69         os << " [";
70         llvm::interleaveComma(caps, os, [&](spirv::Capability cap) {
71           os << spirv::stringifyCapability(cap);
72         });
73         os << "]";
74       }
75       os << " ]\n";
76     }
77     os.flush();
78 
79     return WalkResult::advance();
80   });
81 }
82 
83 namespace mlir {
84 void registerPrintOpAvailabilityPass() {
85   PassRegistration<PrintOpAvailability>();
86 }
87 } // namespace mlir
88 
89 //===----------------------------------------------------------------------===//
90 // Converting target environment pass
91 //===----------------------------------------------------------------------===//
92 
93 namespace {
94 /// A pass for testing SPIR-V op availability.
95 struct ConvertToTargetEnv
96     : public PassWrapper<ConvertToTargetEnv, FunctionPass> {
97   StringRef getArgument() const override { return "test-spirv-target-env"; }
98   StringRef getDescription() const override {
99     return "Test SPIR-V target environment";
100   }
101   void runOnFunction() override;
102 };
103 
104 struct ConvertToAtomCmpExchangeWeak : public RewritePattern {
105   ConvertToAtomCmpExchangeWeak(MLIRContext *context);
106   LogicalResult matchAndRewrite(Operation *op,
107                                 PatternRewriter &rewriter) const override;
108 };
109 
110 struct ConvertToBitReverse : public RewritePattern {
111   ConvertToBitReverse(MLIRContext *context);
112   LogicalResult matchAndRewrite(Operation *op,
113                                 PatternRewriter &rewriter) const override;
114 };
115 
116 struct ConvertToGroupNonUniformBallot : public RewritePattern {
117   ConvertToGroupNonUniformBallot(MLIRContext *context);
118   LogicalResult matchAndRewrite(Operation *op,
119                                 PatternRewriter &rewriter) const override;
120 };
121 
122 struct ConvertToModule : public RewritePattern {
123   ConvertToModule(MLIRContext *context);
124   LogicalResult matchAndRewrite(Operation *op,
125                                 PatternRewriter &rewriter) const override;
126 };
127 
128 struct ConvertToSubgroupBallot : public RewritePattern {
129   ConvertToSubgroupBallot(MLIRContext *context);
130   LogicalResult matchAndRewrite(Operation *op,
131                                 PatternRewriter &rewriter) const override;
132 };
133 } // end anonymous namespace
134 
135 void ConvertToTargetEnv::runOnFunction() {
136   MLIRContext *context = &getContext();
137   FuncOp fn = getFunction();
138 
139   auto targetEnv = fn.getOperation()
140                        ->getAttr(spirv::getTargetEnvAttrName())
141                        .cast<spirv::TargetEnvAttr>();
142   if (!targetEnv) {
143     fn.emitError("missing 'spv.target_env' attribute");
144     return signalPassFailure();
145   }
146 
147   auto target = SPIRVConversionTarget::get(targetEnv);
148 
149   RewritePatternSet patterns(context);
150   patterns.add<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
151                ConvertToGroupNonUniformBallot, ConvertToModule,
152                ConvertToSubgroupBallot>(context);
153 
154   if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
155     return signalPassFailure();
156 }
157 
158 ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
159     : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1,
160                      context, {"spv.AtomicCompareExchangeWeak"}) {}
161 
162 LogicalResult
163 ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
164                                               PatternRewriter &rewriter) const {
165   Value ptr = op->getOperand(0);
166   Value value = op->getOperand(1);
167   Value comparator = op->getOperand(2);
168 
169   // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in
170   // memory semantics to additionally require AtomicStorage capability.
171   rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
172       op, value.getType(), ptr, spirv::Scope::Workgroup,
173       spirv::MemorySemantics::AcquireRelease |
174           spirv::MemorySemantics::AtomicCounterMemory,
175       spirv::MemorySemantics::Acquire, value, comparator);
176   return success();
177 }
178 
179 ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context)
180     : RewritePattern("test.convert_to_bit_reverse_op", 1, context,
181                      {"spv.BitReverse"}) {}
182 
183 LogicalResult
184 ConvertToBitReverse::matchAndRewrite(Operation *op,
185                                      PatternRewriter &rewriter) const {
186   Value predicate = op->getOperand(0);
187 
188   rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
189       op, op->getResult(0).getType(), predicate);
190   return success();
191 }
192 
193 ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
194     MLIRContext *context)
195     : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1, context,
196                      {"spv.GroupNonUniformBallot"}) {}
197 
198 LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
199     Operation *op, PatternRewriter &rewriter) const {
200   Value predicate = op->getOperand(0);
201 
202   rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
203       op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
204   return success();
205 }
206 
207 ConvertToModule::ConvertToModule(MLIRContext *context)
208     : RewritePattern("test.convert_to_module_op", 1, context, {"spv.module"}) {}
209 
210 LogicalResult
211 ConvertToModule::matchAndRewrite(Operation *op,
212                                  PatternRewriter &rewriter) const {
213   rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
214       op, spirv::AddressingModel::PhysicalStorageBuffer64,
215       spirv::MemoryModel::Vulkan);
216   return success();
217 }
218 
219 ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
220     : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context,
221                      {"spv.SubgroupBallotKHR"}) {}
222 
223 LogicalResult
224 ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
225                                          PatternRewriter &rewriter) const {
226   Value predicate = op->getOperand(0);
227 
228   rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
229       op, op->getResult(0).getType(), predicate);
230   return success();
231 }
232 
233 namespace mlir {
234 void registerConvertToTargetEnvPass() {
235   PassRegistration<ConvertToTargetEnv>();
236 }
237 } // namespace mlir
238