1 //===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SPIRV/SPIRVLowering.h" 10 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 11 #include "mlir/Dialect/SPIRV/SPIRVTypes.h" 12 #include "mlir/IR/Function.h" 13 #include "mlir/Pass/Pass.h" 14 15 using namespace mlir; 16 17 //===----------------------------------------------------------------------===// 18 // Printing op availability pass 19 //===----------------------------------------------------------------------===// 20 21 namespace { 22 /// A pass for testing SPIR-V op availability. 23 struct PrintOpAvailability : public FunctionPass<PrintOpAvailability> { 24 void runOnFunction() override; 25 }; 26 } // end anonymous namespace 27 28 void PrintOpAvailability::runOnFunction() { 29 auto f = getFunction(); 30 llvm::outs() << f.getName() << "\n"; 31 32 Dialect *spvDialect = getContext().getRegisteredDialect("spv"); 33 34 f.getOperation()->walk([&](Operation *op) { 35 if (op->getDialect() != spvDialect) 36 return WalkResult::advance(); 37 38 auto opName = op->getName(); 39 auto &os = llvm::outs(); 40 41 if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) 42 os << opName << " min version: " 43 << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n"; 44 45 if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op)) 46 os << opName << " max version: " 47 << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n"; 48 49 if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) { 50 os << opName << " extensions: ["; 51 for (const auto &exts : extension.getExtensions()) { 52 os << " ["; 53 interleaveComma(exts, os, [&](spirv::Extension ext) { 54 os << spirv::stringifyExtension(ext); 55 }); 56 os << "]"; 57 } 58 os << " ]\n"; 59 } 60 61 if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) { 62 os << opName << " capabilities: ["; 63 for (const auto &caps : capability.getCapabilities()) { 64 os << " ["; 65 interleaveComma(caps, os, [&](spirv::Capability cap) { 66 os << spirv::stringifyCapability(cap); 67 }); 68 os << "]"; 69 } 70 os << " ]\n"; 71 } 72 os.flush(); 73 74 return WalkResult::advance(); 75 }); 76 } 77 78 namespace mlir { 79 void registerPrintOpAvailabilityPass() { 80 PassRegistration<PrintOpAvailability> printOpAvailabilityPass( 81 "test-spirv-op-availability", "Test SPIR-V op availability"); 82 } 83 } // namespace mlir 84 85 //===----------------------------------------------------------------------===// 86 // Converting target environment pass 87 //===----------------------------------------------------------------------===// 88 89 namespace { 90 /// A pass for testing SPIR-V op availability. 91 struct ConvertToTargetEnv : public FunctionPass<ConvertToTargetEnv> { 92 void runOnFunction() override; 93 }; 94 95 struct ConvertToAtomCmpExchangeWeak : public RewritePattern { 96 ConvertToAtomCmpExchangeWeak(MLIRContext *context); 97 LogicalResult matchAndRewrite(Operation *op, 98 PatternRewriter &rewriter) const override; 99 }; 100 101 struct ConvertToBitReverse : public RewritePattern { 102 ConvertToBitReverse(MLIRContext *context); 103 LogicalResult matchAndRewrite(Operation *op, 104 PatternRewriter &rewriter) const override; 105 }; 106 107 struct ConvertToGroupNonUniformBallot : public RewritePattern { 108 ConvertToGroupNonUniformBallot(MLIRContext *context); 109 LogicalResult matchAndRewrite(Operation *op, 110 PatternRewriter &rewriter) const override; 111 }; 112 113 struct ConvertToModule : public RewritePattern { 114 ConvertToModule(MLIRContext *context); 115 LogicalResult matchAndRewrite(Operation *op, 116 PatternRewriter &rewriter) const override; 117 }; 118 119 struct ConvertToSubgroupBallot : public RewritePattern { 120 ConvertToSubgroupBallot(MLIRContext *context); 121 LogicalResult matchAndRewrite(Operation *op, 122 PatternRewriter &rewriter) const override; 123 }; 124 } // end anonymous namespace 125 126 void ConvertToTargetEnv::runOnFunction() { 127 MLIRContext *context = &getContext(); 128 FuncOp fn = getFunction(); 129 130 auto targetEnv = fn.getOperation() 131 ->getAttr(spirv::getTargetEnvAttrName()) 132 .cast<spirv::TargetEnvAttr>(); 133 if (!targetEnv) { 134 fn.emitError("missing 'spv.target_env' attribute"); 135 return signalPassFailure(); 136 } 137 138 auto target = spirv::SPIRVConversionTarget::get(targetEnv); 139 140 OwningRewritePatternList patterns; 141 patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse, 142 ConvertToGroupNonUniformBallot, ConvertToModule, 143 ConvertToSubgroupBallot>(context); 144 145 if (failed(applyPartialConversion(fn, *target, patterns))) 146 return signalPassFailure(); 147 } 148 149 ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context) 150 : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 151 {"spv.AtomicCompareExchangeWeak"}, 1, context) {} 152 153 LogicalResult 154 ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op, 155 PatternRewriter &rewriter) const { 156 Value ptr = op->getOperand(0); 157 Value value = op->getOperand(1); 158 Value comparator = op->getOperand(2); 159 160 // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in 161 // memory semantics to additionally require AtomicStorage capability. 162 rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>( 163 op, value.getType(), ptr, spirv::Scope::Workgroup, 164 spirv::MemorySemantics::AcquireRelease | 165 spirv::MemorySemantics::AtomicCounterMemory, 166 spirv::MemorySemantics::Acquire, value, comparator); 167 return success(); 168 } 169 170 ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context) 171 : RewritePattern("test.convert_to_bit_reverse_op", {"spv.BitReverse"}, 1, 172 context) {} 173 174 LogicalResult 175 ConvertToBitReverse::matchAndRewrite(Operation *op, 176 PatternRewriter &rewriter) const { 177 Value predicate = op->getOperand(0); 178 179 rewriter.replaceOpWithNewOp<spirv::BitReverseOp>( 180 op, op->getResult(0).getType(), predicate); 181 return success(); 182 } 183 184 ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot( 185 MLIRContext *context) 186 : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 187 {"spv.GroupNonUniformBallot"}, 1, context) {} 188 189 LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite( 190 Operation *op, PatternRewriter &rewriter) const { 191 Value predicate = op->getOperand(0); 192 193 rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>( 194 op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate); 195 return success(); 196 } 197 198 ConvertToModule::ConvertToModule(MLIRContext *context) 199 : RewritePattern("test.convert_to_module_op", {"spv.module"}, 1, context) {} 200 201 LogicalResult 202 ConvertToModule::matchAndRewrite(Operation *op, 203 PatternRewriter &rewriter) const { 204 rewriter.replaceOpWithNewOp<spirv::ModuleOp>( 205 op, spirv::AddressingModel::PhysicalStorageBuffer64, 206 spirv::MemoryModel::Vulkan); 207 return success(); 208 } 209 210 ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context) 211 : RewritePattern("test.convert_to_subgroup_ballot_op", 212 {"spv.SubgroupBallotKHR"}, 1, context) {} 213 214 LogicalResult 215 ConvertToSubgroupBallot::matchAndRewrite(Operation *op, 216 PatternRewriter &rewriter) const { 217 Value predicate = op->getOperand(0); 218 219 rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>( 220 op, op->getResult(0).getType(), predicate); 221 return success(); 222 } 223 224 namespace mlir { 225 void registerConvertToTargetEnvPass() { 226 PassRegistration<ConvertToTargetEnv> convertToTargetEnvPass( 227 "test-spirv-target-env", "Test SPIR-V target environment"); 228 } 229 } // namespace mlir 230