1 //===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 10 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 11 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 12 #include "mlir/IR/BuiltinOps.h" 13 #include "mlir/Pass/Pass.h" 14 15 using namespace mlir; 16 17 //===----------------------------------------------------------------------===// 18 // Printing op availability pass 19 //===----------------------------------------------------------------------===// 20 21 namespace { 22 /// A pass for testing SPIR-V op availability. 23 struct PrintOpAvailability 24 : public PassWrapper<PrintOpAvailability, OperationPass<FuncOp>> { 25 void runOnOperation() override; 26 StringRef getArgument() const final { return "test-spirv-op-availability"; } 27 StringRef getDescription() const final { 28 return "Test SPIR-V op availability"; 29 } 30 }; 31 } // namespace 32 33 void PrintOpAvailability::runOnOperation() { 34 auto f = getOperation(); 35 llvm::outs() << f.getName() << "\n"; 36 37 Dialect *spvDialect = getContext().getLoadedDialect("spv"); 38 39 f->walk([&](Operation *op) { 40 if (op->getDialect() != spvDialect) 41 return WalkResult::advance(); 42 43 auto opName = op->getName(); 44 auto &os = llvm::outs(); 45 46 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) { 47 Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion(); 48 os << opName << " min version: "; 49 if (minVersion) 50 os << spirv::stringifyVersion(*minVersion) << "\n"; 51 else 52 os << "None\n"; 53 } 54 55 if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) { 56 Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion(); 57 os << opName << " max version: "; 58 if (maxVersion) 59 os << spirv::stringifyVersion(*maxVersion) << "\n"; 60 else 61 os << "None\n"; 62 } 63 64 if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) { 65 os << opName << " extensions: ["; 66 for (const auto &exts : extension.getExtensions()) { 67 os << " ["; 68 llvm::interleaveComma(exts, os, [&](spirv::Extension ext) { 69 os << spirv::stringifyExtension(ext); 70 }); 71 os << "]"; 72 } 73 os << " ]\n"; 74 } 75 76 if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) { 77 os << opName << " capabilities: ["; 78 for (const auto &caps : capability.getCapabilities()) { 79 os << " ["; 80 llvm::interleaveComma(caps, os, [&](spirv::Capability cap) { 81 os << spirv::stringifyCapability(cap); 82 }); 83 os << "]"; 84 } 85 os << " ]\n"; 86 } 87 os.flush(); 88 89 return WalkResult::advance(); 90 }); 91 } 92 93 namespace mlir { 94 void registerPrintSpirvAvailabilityPass() { 95 PassRegistration<PrintOpAvailability>(); 96 } 97 } // namespace mlir 98 99 //===----------------------------------------------------------------------===// 100 // Converting target environment pass 101 //===----------------------------------------------------------------------===// 102 103 namespace { 104 /// A pass for testing SPIR-V op availability. 105 struct ConvertToTargetEnv 106 : public PassWrapper<ConvertToTargetEnv, OperationPass<FuncOp>> { 107 StringRef getArgument() const override { return "test-spirv-target-env"; } 108 StringRef getDescription() const override { 109 return "Test SPIR-V target environment"; 110 } 111 void runOnOperation() override; 112 }; 113 114 struct ConvertToAtomCmpExchangeWeak : public RewritePattern { 115 ConvertToAtomCmpExchangeWeak(MLIRContext *context); 116 LogicalResult matchAndRewrite(Operation *op, 117 PatternRewriter &rewriter) const override; 118 }; 119 120 struct ConvertToBitReverse : public RewritePattern { 121 ConvertToBitReverse(MLIRContext *context); 122 LogicalResult matchAndRewrite(Operation *op, 123 PatternRewriter &rewriter) const override; 124 }; 125 126 struct ConvertToGroupNonUniformBallot : public RewritePattern { 127 ConvertToGroupNonUniformBallot(MLIRContext *context); 128 LogicalResult matchAndRewrite(Operation *op, 129 PatternRewriter &rewriter) const override; 130 }; 131 132 struct ConvertToModule : public RewritePattern { 133 ConvertToModule(MLIRContext *context); 134 LogicalResult matchAndRewrite(Operation *op, 135 PatternRewriter &rewriter) const override; 136 }; 137 138 struct ConvertToSubgroupBallot : public RewritePattern { 139 ConvertToSubgroupBallot(MLIRContext *context); 140 LogicalResult matchAndRewrite(Operation *op, 141 PatternRewriter &rewriter) const override; 142 }; 143 } // namespace 144 145 void ConvertToTargetEnv::runOnOperation() { 146 MLIRContext *context = &getContext(); 147 FuncOp fn = getOperation(); 148 149 auto targetEnv = fn.getOperation() 150 ->getAttr(spirv::getTargetEnvAttrName()) 151 .cast<spirv::TargetEnvAttr>(); 152 if (!targetEnv) { 153 fn.emitError("missing 'spv.target_env' attribute"); 154 return signalPassFailure(); 155 } 156 157 auto target = SPIRVConversionTarget::get(targetEnv); 158 159 RewritePatternSet patterns(context); 160 patterns.add<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse, 161 ConvertToGroupNonUniformBallot, ConvertToModule, 162 ConvertToSubgroupBallot>(context); 163 164 if (failed(applyPartialConversion(fn, *target, std::move(patterns)))) 165 return signalPassFailure(); 166 } 167 168 ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context) 169 : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1, 170 context, {"spv.AtomicCompareExchangeWeak"}) {} 171 172 LogicalResult 173 ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op, 174 PatternRewriter &rewriter) const { 175 Value ptr = op->getOperand(0); 176 Value value = op->getOperand(1); 177 Value comparator = op->getOperand(2); 178 179 // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in 180 // memory semantics to additionally require AtomicStorage capability. 181 rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>( 182 op, value.getType(), ptr, spirv::Scope::Workgroup, 183 spirv::MemorySemantics::AcquireRelease | 184 spirv::MemorySemantics::AtomicCounterMemory, 185 spirv::MemorySemantics::Acquire, value, comparator); 186 return success(); 187 } 188 189 ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context) 190 : RewritePattern("test.convert_to_bit_reverse_op", 1, context, 191 {"spv.BitReverse"}) {} 192 193 LogicalResult 194 ConvertToBitReverse::matchAndRewrite(Operation *op, 195 PatternRewriter &rewriter) const { 196 Value predicate = op->getOperand(0); 197 198 rewriter.replaceOpWithNewOp<spirv::BitReverseOp>( 199 op, op->getResult(0).getType(), predicate); 200 return success(); 201 } 202 203 ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot( 204 MLIRContext *context) 205 : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1, context, 206 {"spv.GroupNonUniformBallot"}) {} 207 208 LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite( 209 Operation *op, PatternRewriter &rewriter) const { 210 Value predicate = op->getOperand(0); 211 212 rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>( 213 op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate); 214 return success(); 215 } 216 217 ConvertToModule::ConvertToModule(MLIRContext *context) 218 : RewritePattern("test.convert_to_module_op", 1, context, {"spv.module"}) {} 219 220 LogicalResult 221 ConvertToModule::matchAndRewrite(Operation *op, 222 PatternRewriter &rewriter) const { 223 rewriter.replaceOpWithNewOp<spirv::ModuleOp>( 224 op, spirv::AddressingModel::PhysicalStorageBuffer64, 225 spirv::MemoryModel::Vulkan); 226 return success(); 227 } 228 229 ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context) 230 : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context, 231 {"spv.SubgroupBallotKHR"}) {} 232 233 LogicalResult 234 ConvertToSubgroupBallot::matchAndRewrite(Operation *op, 235 PatternRewriter &rewriter) const { 236 Value predicate = op->getOperand(0); 237 238 rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>( 239 op, op->getResult(0).getType(), predicate); 240 return success(); 241 } 242 243 namespace mlir { 244 void registerConvertToTargetEnvPass() { 245 PassRegistration<ConvertToTargetEnv>(); 246 } 247 } // namespace mlir 248