1 //===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SPIRV/SPIRVLowering.h"
10 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
11 #include "mlir/Dialect/SPIRV/SPIRVTypes.h"
12 #include "mlir/IR/Function.h"
13 #include "mlir/Pass/Pass.h"
14 
15 using namespace mlir;
16 
17 //===----------------------------------------------------------------------===//
18 // Printing op availability pass
19 //===----------------------------------------------------------------------===//
20 
21 namespace {
22 /// A pass for testing SPIR-V op availability.
23 struct PrintOpAvailability : public FunctionPass<PrintOpAvailability> {
24   void runOnFunction() override;
25 };
26 } // end anonymous namespace
27 
28 void PrintOpAvailability::runOnFunction() {
29   auto f = getFunction();
30   llvm::outs() << f.getName() << "\n";
31 
32   Dialect *spvDialect = getContext().getRegisteredDialect("spv");
33 
34   f.getOperation()->walk([&](Operation *op) {
35     if (op->getDialect() != spvDialect)
36       return WalkResult::advance();
37 
38     auto opName = op->getName();
39     auto &os = llvm::outs();
40 
41     if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op))
42       os << opName << " min version: "
43          << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n";
44 
45     if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op))
46       os << opName << " max version: "
47          << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n";
48 
49     if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
50       os << opName << " extensions: [";
51       for (const auto &exts : extension.getExtensions()) {
52         os << " [";
53         interleaveComma(exts, os, [&](spirv::Extension ext) {
54           os << spirv::stringifyExtension(ext);
55         });
56         os << "]";
57       }
58       os << " ]\n";
59     }
60 
61     if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
62       os << opName << " capabilities: [";
63       for (const auto &caps : capability.getCapabilities()) {
64         os << " [";
65         interleaveComma(caps, os, [&](spirv::Capability cap) {
66           os << spirv::stringifyCapability(cap);
67         });
68         os << "]";
69       }
70       os << " ]\n";
71     }
72     os.flush();
73 
74     return WalkResult::advance();
75   });
76 }
77 
78 namespace mlir {
79 void registerPrintOpAvailabilityPass() {
80   PassRegistration<PrintOpAvailability> printOpAvailabilityPass(
81       "test-spirv-op-availability", "Test SPIR-V op availability");
82 }
83 } // namespace mlir
84 
85 //===----------------------------------------------------------------------===//
86 // Converting target environment pass
87 //===----------------------------------------------------------------------===//
88 
89 namespace {
90 /// A pass for testing SPIR-V op availability.
91 struct ConvertToTargetEnv : public FunctionPass<ConvertToTargetEnv> {
92   void runOnFunction() override;
93 };
94 
95 struct ConvertToAtomCmpExchangeWeak : public RewritePattern {
96   ConvertToAtomCmpExchangeWeak(MLIRContext *context);
97   LogicalResult matchAndRewrite(Operation *op,
98                                 PatternRewriter &rewriter) const override;
99 };
100 
101 struct ConvertToBitReverse : public RewritePattern {
102   ConvertToBitReverse(MLIRContext *context);
103   LogicalResult matchAndRewrite(Operation *op,
104                                 PatternRewriter &rewriter) const override;
105 };
106 
107 struct ConvertToGroupNonUniformBallot : public RewritePattern {
108   ConvertToGroupNonUniformBallot(MLIRContext *context);
109   LogicalResult matchAndRewrite(Operation *op,
110                                 PatternRewriter &rewriter) const override;
111 };
112 
113 struct ConvertToModule : public RewritePattern {
114   ConvertToModule(MLIRContext *context);
115   LogicalResult matchAndRewrite(Operation *op,
116                                 PatternRewriter &rewriter) const override;
117 };
118 
119 struct ConvertToSubgroupBallot : public RewritePattern {
120   ConvertToSubgroupBallot(MLIRContext *context);
121   LogicalResult matchAndRewrite(Operation *op,
122                                 PatternRewriter &rewriter) const override;
123 };
124 } // end anonymous namespace
125 
126 void ConvertToTargetEnv::runOnFunction() {
127   MLIRContext *context = &getContext();
128   FuncOp fn = getFunction();
129 
130   auto targetEnv = fn.getOperation()
131                        ->getAttr(spirv::getTargetEnvAttrName())
132                        .cast<spirv::TargetEnvAttr>();
133   if (!targetEnv) {
134     fn.emitError("missing 'spv.target_env' attribute");
135     return signalPassFailure();
136   }
137 
138   auto target = spirv::SPIRVConversionTarget::get(targetEnv);
139 
140   OwningRewritePatternList patterns;
141   patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
142                   ConvertToGroupNonUniformBallot, ConvertToModule,
143                   ConvertToSubgroupBallot>(context);
144 
145   if (failed(applyPartialConversion(fn, *target, patterns)))
146     return signalPassFailure();
147 }
148 
149 ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
150     : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op",
151                      {"spv.AtomicCompareExchangeWeak"}, 1, context) {}
152 
153 LogicalResult
154 ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
155                                               PatternRewriter &rewriter) const {
156   Value ptr = op->getOperand(0);
157   Value value = op->getOperand(1);
158   Value comparator = op->getOperand(2);
159 
160   // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in
161   // memory semantics to additionally require AtomicStorage capability.
162   rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
163       op, value.getType(), ptr, spirv::Scope::Workgroup,
164       spirv::MemorySemantics::AcquireRelease |
165           spirv::MemorySemantics::AtomicCounterMemory,
166       spirv::MemorySemantics::Acquire, value, comparator);
167   return success();
168 }
169 
170 ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context)
171     : RewritePattern("test.convert_to_bit_reverse_op", {"spv.BitReverse"}, 1,
172                      context) {}
173 
174 LogicalResult
175 ConvertToBitReverse::matchAndRewrite(Operation *op,
176                                      PatternRewriter &rewriter) const {
177   Value predicate = op->getOperand(0);
178 
179   rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
180       op, op->getResult(0).getType(), predicate);
181   return success();
182 }
183 
184 ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
185     MLIRContext *context)
186     : RewritePattern("test.convert_to_group_non_uniform_ballot_op",
187                      {"spv.GroupNonUniformBallot"}, 1, context) {}
188 
189 LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
190     Operation *op, PatternRewriter &rewriter) const {
191   Value predicate = op->getOperand(0);
192 
193   rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
194       op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
195   return success();
196 }
197 
198 ConvertToModule::ConvertToModule(MLIRContext *context)
199     : RewritePattern("test.convert_to_module_op", {"spv.module"}, 1, context) {}
200 
201 LogicalResult
202 ConvertToModule::matchAndRewrite(Operation *op,
203                                  PatternRewriter &rewriter) const {
204   rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
205       op, spirv::AddressingModel::PhysicalStorageBuffer64,
206       spirv::MemoryModel::Vulkan);
207   return success();
208 }
209 
210 ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
211     : RewritePattern("test.convert_to_subgroup_ballot_op",
212                      {"spv.SubgroupBallotKHR"}, 1, context) {}
213 
214 LogicalResult
215 ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
216                                          PatternRewriter &rewriter) const {
217   Value predicate = op->getOperand(0);
218 
219   rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
220       op, op->getResult(0).getType(), predicate);
221   return success();
222 }
223 
224 namespace mlir {
225 void registerConvertToTargetEnvPass() {
226   PassRegistration<ConvertToTargetEnv> convertToTargetEnvPass(
227       "test-spirv-target-env", "Test SPIR-V target environment");
228 }
229 } // namespace mlir
230