1 //===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/Func/IR/FuncOps.h" 10 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 11 #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" 12 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 13 #include "mlir/IR/BuiltinOps.h" 14 #include "mlir/Pass/Pass.h" 15 16 using namespace mlir; 17 18 //===----------------------------------------------------------------------===// 19 // Printing op availability pass 20 //===----------------------------------------------------------------------===// 21 22 namespace { 23 /// A pass for testing SPIR-V op availability. 24 struct PrintOpAvailability 25 : public PassWrapper<PrintOpAvailability, OperationPass<func::FuncOp>> { 26 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrintOpAvailability) 27 28 void runOnOperation() override; 29 StringRef getArgument() const final { return "test-spirv-op-availability"; } 30 StringRef getDescription() const final { 31 return "Test SPIR-V op availability"; 32 } 33 }; 34 } // namespace 35 36 void PrintOpAvailability::runOnOperation() { 37 auto f = getOperation(); 38 llvm::outs() << f.getName() << "\n"; 39 40 Dialect *spvDialect = getContext().getLoadedDialect("spv"); 41 42 f->walk([&](Operation *op) { 43 if (op->getDialect() != spvDialect) 44 return WalkResult::advance(); 45 46 auto opName = op->getName(); 47 auto &os = llvm::outs(); 48 49 if (auto minVersionIfx = dyn_cast<spirv::QueryMinVersionInterface>(op)) { 50 Optional<spirv::Version> minVersion = minVersionIfx.getMinVersion(); 51 os << opName << " min version: "; 52 if (minVersion) 53 os << spirv::stringifyVersion(*minVersion) << "\n"; 54 else 55 os << "None\n"; 56 } 57 58 if (auto maxVersionIfx = dyn_cast<spirv::QueryMaxVersionInterface>(op)) { 59 Optional<spirv::Version> maxVersion = maxVersionIfx.getMaxVersion(); 60 os << opName << " max version: "; 61 if (maxVersion) 62 os << spirv::stringifyVersion(*maxVersion) << "\n"; 63 else 64 os << "None\n"; 65 } 66 67 if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) { 68 os << opName << " extensions: ["; 69 for (const auto &exts : extension.getExtensions()) { 70 os << " ["; 71 llvm::interleaveComma(exts, os, [&](spirv::Extension ext) { 72 os << spirv::stringifyExtension(ext); 73 }); 74 os << "]"; 75 } 76 os << " ]\n"; 77 } 78 79 if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) { 80 os << opName << " capabilities: ["; 81 for (const auto &caps : capability.getCapabilities()) { 82 os << " ["; 83 llvm::interleaveComma(caps, os, [&](spirv::Capability cap) { 84 os << spirv::stringifyCapability(cap); 85 }); 86 os << "]"; 87 } 88 os << " ]\n"; 89 } 90 os.flush(); 91 92 return WalkResult::advance(); 93 }); 94 } 95 96 namespace mlir { 97 void registerPrintSpirvAvailabilityPass() { 98 PassRegistration<PrintOpAvailability>(); 99 } 100 } // namespace mlir 101 102 //===----------------------------------------------------------------------===// 103 // Converting target environment pass 104 //===----------------------------------------------------------------------===// 105 106 namespace { 107 /// A pass for testing SPIR-V op availability. 108 struct ConvertToTargetEnv 109 : public PassWrapper<ConvertToTargetEnv, OperationPass<func::FuncOp>> { 110 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertToTargetEnv) 111 112 StringRef getArgument() const override { return "test-spirv-target-env"; } 113 StringRef getDescription() const override { 114 return "Test SPIR-V target environment"; 115 } 116 void runOnOperation() override; 117 }; 118 119 struct ConvertToAtomCmpExchangeWeak : public RewritePattern { 120 ConvertToAtomCmpExchangeWeak(MLIRContext *context); 121 LogicalResult matchAndRewrite(Operation *op, 122 PatternRewriter &rewriter) const override; 123 }; 124 125 struct ConvertToBitReverse : public RewritePattern { 126 ConvertToBitReverse(MLIRContext *context); 127 LogicalResult matchAndRewrite(Operation *op, 128 PatternRewriter &rewriter) const override; 129 }; 130 131 struct ConvertToGroupNonUniformBallot : public RewritePattern { 132 ConvertToGroupNonUniformBallot(MLIRContext *context); 133 LogicalResult matchAndRewrite(Operation *op, 134 PatternRewriter &rewriter) const override; 135 }; 136 137 struct ConvertToModule : public RewritePattern { 138 ConvertToModule(MLIRContext *context); 139 LogicalResult matchAndRewrite(Operation *op, 140 PatternRewriter &rewriter) const override; 141 }; 142 143 struct ConvertToSubgroupBallot : public RewritePattern { 144 ConvertToSubgroupBallot(MLIRContext *context); 145 LogicalResult matchAndRewrite(Operation *op, 146 PatternRewriter &rewriter) const override; 147 }; 148 } // namespace 149 150 void ConvertToTargetEnv::runOnOperation() { 151 MLIRContext *context = &getContext(); 152 func::FuncOp fn = getOperation(); 153 154 auto targetEnv = fn.getOperation() 155 ->getAttr(spirv::getTargetEnvAttrName()) 156 .cast<spirv::TargetEnvAttr>(); 157 if (!targetEnv) { 158 fn.emitError("missing 'spv.target_env' attribute"); 159 return signalPassFailure(); 160 } 161 162 auto target = SPIRVConversionTarget::get(targetEnv); 163 164 RewritePatternSet patterns(context); 165 patterns.add<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse, 166 ConvertToGroupNonUniformBallot, ConvertToModule, 167 ConvertToSubgroupBallot>(context); 168 169 if (failed(applyPartialConversion(fn, *target, std::move(patterns)))) 170 return signalPassFailure(); 171 } 172 173 ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context) 174 : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 1, 175 context, {"spv.AtomicCompareExchangeWeak"}) {} 176 177 LogicalResult 178 ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op, 179 PatternRewriter &rewriter) const { 180 Value ptr = op->getOperand(0); 181 Value value = op->getOperand(1); 182 Value comparator = op->getOperand(2); 183 184 // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in 185 // memory semantics to additionally require AtomicStorage capability. 186 rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>( 187 op, value.getType(), ptr, spirv::Scope::Workgroup, 188 spirv::MemorySemantics::AcquireRelease | 189 spirv::MemorySemantics::AtomicCounterMemory, 190 spirv::MemorySemantics::Acquire, value, comparator); 191 return success(); 192 } 193 194 ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context) 195 : RewritePattern("test.convert_to_bit_reverse_op", 1, context, 196 {"spv.BitReverse"}) {} 197 198 LogicalResult 199 ConvertToBitReverse::matchAndRewrite(Operation *op, 200 PatternRewriter &rewriter) const { 201 Value predicate = op->getOperand(0); 202 203 rewriter.replaceOpWithNewOp<spirv::BitReverseOp>( 204 op, op->getResult(0).getType(), predicate); 205 return success(); 206 } 207 208 ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot( 209 MLIRContext *context) 210 : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 1, context, 211 {"spv.GroupNonUniformBallot"}) {} 212 213 LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite( 214 Operation *op, PatternRewriter &rewriter) const { 215 Value predicate = op->getOperand(0); 216 217 rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>( 218 op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate); 219 return success(); 220 } 221 222 ConvertToModule::ConvertToModule(MLIRContext *context) 223 : RewritePattern("test.convert_to_module_op", 1, context, {"spv.module"}) {} 224 225 LogicalResult 226 ConvertToModule::matchAndRewrite(Operation *op, 227 PatternRewriter &rewriter) const { 228 rewriter.replaceOpWithNewOp<spirv::ModuleOp>( 229 op, spirv::AddressingModel::PhysicalStorageBuffer64, 230 spirv::MemoryModel::Vulkan); 231 return success(); 232 } 233 234 ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context) 235 : RewritePattern("test.convert_to_subgroup_ballot_op", 1, context, 236 {"spv.SubgroupBallotKHR"}) {} 237 238 LogicalResult 239 ConvertToSubgroupBallot::matchAndRewrite(Operation *op, 240 PatternRewriter &rewriter) const { 241 Value predicate = op->getOperand(0); 242 243 rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>( 244 op, op->getResult(0).getType(), predicate); 245 return success(); 246 } 247 248 namespace mlir { 249 void registerConvertToTargetEnvPass() { 250 PassRegistration<ConvertToTargetEnv>(); 251 } 252 } // namespace mlir 253