1 //===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Func/IR/FuncOps.h"
10 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
11 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
12 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/Pass/Pass.h"
15 
16 using namespace mlir;
17 
18 //===----------------------------------------------------------------------===//
19 // Printing op availability pass
20 //===----------------------------------------------------------------------===//
21 
22 namespace {
23 /// A pass for testing SPIR-V op availability.
24 struct PrintOpAvailability
25     : public PassWrapper<PrintOpAvailability, OperationPass<func::FuncOp>> {
26   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability)
27 
28   void runOnOperation() override;
getArgument__anonf839eacb0111::PrintOpAvailability29   StringRef getArgument() const final { return "test-spirv-op-availability"; }
getDescription__anonf839eacb0111::PrintOpAvailability30   StringRef getDescription() const final {
31     return "Test SPIR-V op availability";
32   }
33 };
34 } // namespace
35 
runOnOperation()36 void PrintOpAvailability::runOnOperation() {
37   auto f = getOperation();
38   llvm::outs() << f.getName() << "\n";
39 
40   Dialect *spvDialect = getContext().getLoadedDialect("spv");
41 
42   f->walk([&](Operation *op) {
43     if (op->getDialect() != spvDialect)
44       return WalkResult::advance();
45 
46     auto opName = op->getName();
47     auto &os = llvm::outs();
48 
49     if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) {
50       Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion();
51       os << opName << " min version: ";
52       if (minVersion)
53         os << spirv::stringifyVersion(*minVersion) << "\n";
54       else
55         os << "None\n";
56     }
57 
58     if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) {
59       Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion();
60       os << opName << " max version: ";
61       if (maxVersion)
62         os << spirv::stringifyVersion(*maxVersion) << "\n";
63       else
64         os << "None\n";
65     }
66 
67     if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) {
68       os << opName << " extensions: [";
69       for (const auto &exts : extension.getExtensions()) {
70         os << " [";
71         llvm::interleaveComma(exts, os, [&](spirv::Extension ext) {
72           os << spirv::stringifyExtension(ext);
73         });
74         os << "]";
75       }
76       os << " ]\n";
77     }
78 
79     if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) {
80       os << opName << " capabilities: [";
81       for (const auto &caps : capability.getCapabilities()) {
82         os << " [";
83         llvm::interleaveComma(caps, os, [&](spirv::Capability cap) {
84           os << spirv::stringifyCapability(cap);
85         });
86         os << "]";
87       }
88       os << " ]\n";
89     }
90     os.flush();
91 
92     return WalkResult::advance();
93   });
94 }
95 
96 namespace mlir {
registerPrintSpirvAvailabilityPass()97 void registerPrintSpirvAvailabilityPass() {
98   PassRegistration<PrintOpAvailability>();
99 }
100 } // namespace mlir
101 
102 //===----------------------------------------------------------------------===//
103 // Converting target environment pass
104 //===----------------------------------------------------------------------===//
105 
106 namespace {
107 /// A pass for testing SPIR-V op availability.
108 struct ConvertToTargetEnv
109     : public PassWrapper<ConvertToTargetEnv, OperationPass<func::FuncOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID__anonf839eacb0511::ConvertToTargetEnv110   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertToTargetEnv)
111 
112   StringRef getArgument() const override { return "test-spirv-target-env"; }
getDescription__anonf839eacb0511::ConvertToTargetEnv113   StringRef getDescription() const override {
114     return "Test SPIR-V target environment";
115   }
116   void runOnOperation() override;
117 };
118 
119 struct ConvertToAtomCmpExchangeWeak : public RewritePattern {
120   ConvertToAtomCmpExchangeWeak(MLIRContext *context);
121   LogicalResult matchAndRewrite(Operation *op,
122                                 PatternRewriter &rewriter) const override;
123 };
124 
125 struct ConvertToBitReverse : public RewritePattern {
126   ConvertToBitReverse(MLIRContext *context);
127   LogicalResult matchAndRewrite(Operation *op,
128                                 PatternRewriter &rewriter) const override;
129 };
130 
131 struct ConvertToGroupNonUniformBallot : public RewritePattern {
132   ConvertToGroupNonUniformBallot(MLIRContext *context);
133   LogicalResult matchAndRewrite(Operation *op,
134                                 PatternRewriter &rewriter) const override;
135 };
136 
137 struct ConvertToModule : public RewritePattern {
138   ConvertToModule(MLIRContext *context);
139   LogicalResult matchAndRewrite(Operation *op,
140                                 PatternRewriter &rewriter) const override;
141 };
142 
143 struct ConvertToSubgroupBallot : public RewritePattern {
144   ConvertToSubgroupBallot(MLIRContext *context);
145   LogicalResult matchAndRewrite(Operation *op,
146                                 PatternRewriter &rewriter) const override;
147 };
148 } // namespace
149 
runOnOperation()150 void ConvertToTargetEnv::runOnOperation() {
151   MLIRContext *context = &getContext();
152   func::FuncOp fn = getOperation();
153 
154   auto targetEnv = fn.getOperation()
155                        ->getAttr(spirv::getTargetEnvAttrName())
156                        .cast<spirv::TargetEnvAttr>();
157   if (!targetEnv) {
158     fn.emitError("missing 'spv.target_env' attribute");
159     return signalPassFailure();
160   }
161 
162   auto target = SPIRVConversionTarget::get(targetEnv);
163 
164   RewritePatternSet patterns(context);
165   patterns.add<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse,
166                ConvertToGroupNonUniformBallot, ConvertToModule,
167                ConvertToSubgroupBallot>(context);
168 
169   if (failed(applyPartialConversion(fn, *target, std::move(patterns))))
170     return signalPassFailure();
171 }
172 
ConvertToAtomCmpExchangeWeak(MLIRContext * context)173 ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context)
174     : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1,
175                      context, {"spv.AtomicCompareExchangeWeak"}) {}
176 
177 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const178 ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op,
179                                               PatternRewriter &rewriter) const {
180   Value ptr = op->getOperand(0);
181   Value value = op->getOperand(1);
182   Value comparator = op->getOperand(2);
183 
184   // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in
185   // memory semantics to additionally require AtomicStorage capability.
186   rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>(
187       op, value.getType(), ptr, spirv::Scope::Workgroup,
188       spirv::MemorySemantics::AcquireRelease |
189           spirv::MemorySemantics::AtomicCounterMemory,
190       spirv::MemorySemantics::Acquire, value, comparator);
191   return success();
192 }
193 
ConvertToBitReverse(MLIRContext * context)194 ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context)
195     : RewritePattern("test.convert_to_bit_reverse_op", 1, context,
196                      {"spv.BitReverse"}) {}
197 
198 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const199 ConvertToBitReverse::matchAndRewrite(Operation *op,
200                                      PatternRewriter &rewriter) const {
201   Value predicate = op->getOperand(0);
202 
203   rewriter.replaceOpWithNewOp<spirv::BitReverseOp>(
204       op, op->getResult(0).getType(), predicate);
205   return success();
206 }
207 
ConvertToGroupNonUniformBallot(MLIRContext * context)208 ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot(
209     MLIRContext *context)
210     : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1, context,
211                      {"spv.GroupNonUniformBallot"}) {}
212 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const213 LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite(
214     Operation *op, PatternRewriter &rewriter) const {
215   Value predicate = op->getOperand(0);
216 
217   rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>(
218       op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate);
219   return success();
220 }
221 
ConvertToModule(MLIRContext * context)222 ConvertToModule::ConvertToModule(MLIRContext *context)
223     : RewritePattern("test.convert_to_module_op", 1, context, {"spv.module"}) {}
224 
225 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const226 ConvertToModule::matchAndRewrite(Operation *op,
227                                  PatternRewriter &rewriter) const {
228   rewriter.replaceOpWithNewOp<spirv::ModuleOp>(
229       op, spirv::AddressingModel::PhysicalStorageBuffer64,
230       spirv::MemoryModel::Vulkan);
231   return success();
232 }
233 
ConvertToSubgroupBallot(MLIRContext * context)234 ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context)
235     : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context,
236                      {"spv.SubgroupBallotKHR"}) {}
237 
238 LogicalResult
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const239 ConvertToSubgroupBallot::matchAndRewrite(Operation *op,
240                                          PatternRewriter &rewriter) const {
241   Value predicate = op->getOperand(0);
242 
243   rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>(
244       op, op->getResult(0).getType(), predicate);
245   return success();
246 }
247 
248 namespace mlir {
registerConvertToTargetEnvPass()249 void registerConvertToTargetEnvPass() {
250   PassRegistration<ConvertToTargetEnv>();
251 }
252 } // namespace mlir
253