1 //===- TestAvailability.cpp - Pass to test SPIR-V op availability ---------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/Dialect/SPIRV/SPIRVLowering.h" 10 #include "mlir/Dialect/SPIRV/SPIRVOps.h" 11 #include "mlir/Dialect/SPIRV/SPIRVTypes.h" 12 #include "mlir/IR/Function.h" 13 #include "mlir/Pass/Pass.h" 14 15 using namespace mlir; 16 17 //===----------------------------------------------------------------------===// 18 // Printing op availability pass 19 //===----------------------------------------------------------------------===// 20 21 namespace { 22 /// A pass for testing SPIR-V op availability. 23 struct PrintOpAvailability 24 : public PassWrapper<PrintOpAvailability, FunctionPass> { 25 void runOnFunction() override; 26 }; 27 } // end anonymous namespace 28 29 void PrintOpAvailability::runOnFunction() { 30 auto f = getFunction(); 31 llvm::outs() << f.getName() << "\n"; 32 33 Dialect *spvDialect = getContext().getRegisteredDialect("spv"); 34 35 f.getOperation()->walk([&](Operation *op) { 36 if (op->getDialect() != spvDialect) 37 return WalkResult::advance(); 38 39 auto opName = op->getName(); 40 auto &os = llvm::outs(); 41 42 if (auto minVersion = dyn_cast<spirv::QueryMinVersionInterface>(op)) 43 os << opName << " min version: " 44 << spirv::stringifyVersion(minVersion.getMinVersion()) << "\n"; 45 46 if (auto maxVersion = dyn_cast<spirv::QueryMaxVersionInterface>(op)) 47 os << opName << " max version: " 48 << spirv::stringifyVersion(maxVersion.getMaxVersion()) << "\n"; 49 50 if (auto extension = dyn_cast<spirv::QueryExtensionInterface>(op)) { 51 os << opName << " extensions: ["; 52 for (const auto &exts : extension.getExtensions()) { 53 os << " ["; 54 llvm::interleaveComma(exts, os, [&](spirv::Extension ext) { 55 os << spirv::stringifyExtension(ext); 56 }); 57 os << "]"; 58 } 59 os << " ]\n"; 60 } 61 62 if (auto capability = dyn_cast<spirv::QueryCapabilityInterface>(op)) { 63 os << opName << " capabilities: ["; 64 for (const auto &caps : capability.getCapabilities()) { 65 os << " ["; 66 llvm::interleaveComma(caps, os, [&](spirv::Capability cap) { 67 os << spirv::stringifyCapability(cap); 68 }); 69 os << "]"; 70 } 71 os << " ]\n"; 72 } 73 os.flush(); 74 75 return WalkResult::advance(); 76 }); 77 } 78 79 namespace mlir { 80 void registerPrintOpAvailabilityPass() { 81 PassRegistration<PrintOpAvailability> printOpAvailabilityPass( 82 "test-spirv-op-availability", "Test SPIR-V op availability"); 83 } 84 } // namespace mlir 85 86 //===----------------------------------------------------------------------===// 87 // Converting target environment pass 88 //===----------------------------------------------------------------------===// 89 90 namespace { 91 /// A pass for testing SPIR-V op availability. 92 struct ConvertToTargetEnv 93 : public PassWrapper<ConvertToTargetEnv, FunctionPass> { 94 void runOnFunction() override; 95 }; 96 97 struct ConvertToAtomCmpExchangeWeak : public RewritePattern { 98 ConvertToAtomCmpExchangeWeak(MLIRContext *context); 99 LogicalResult matchAndRewrite(Operation *op, 100 PatternRewriter &rewriter) const override; 101 }; 102 103 struct ConvertToBitReverse : public RewritePattern { 104 ConvertToBitReverse(MLIRContext *context); 105 LogicalResult matchAndRewrite(Operation *op, 106 PatternRewriter &rewriter) const override; 107 }; 108 109 struct ConvertToGroupNonUniformBallot : public RewritePattern { 110 ConvertToGroupNonUniformBallot(MLIRContext *context); 111 LogicalResult matchAndRewrite(Operation *op, 112 PatternRewriter &rewriter) const override; 113 }; 114 115 struct ConvertToModule : public RewritePattern { 116 ConvertToModule(MLIRContext *context); 117 LogicalResult matchAndRewrite(Operation *op, 118 PatternRewriter &rewriter) const override; 119 }; 120 121 struct ConvertToSubgroupBallot : public RewritePattern { 122 ConvertToSubgroupBallot(MLIRContext *context); 123 LogicalResult matchAndRewrite(Operation *op, 124 PatternRewriter &rewriter) const override; 125 }; 126 } // end anonymous namespace 127 128 void ConvertToTargetEnv::runOnFunction() { 129 MLIRContext *context = &getContext(); 130 FuncOp fn = getFunction(); 131 132 auto targetEnv = fn.getOperation() 133 ->getAttr(spirv::getTargetEnvAttrName()) 134 .cast<spirv::TargetEnvAttr>(); 135 if (!targetEnv) { 136 fn.emitError("missing 'spv.target_env' attribute"); 137 return signalPassFailure(); 138 } 139 140 auto target = spirv::SPIRVConversionTarget::get(targetEnv); 141 142 OwningRewritePatternList patterns; 143 patterns.insert<ConvertToAtomCmpExchangeWeak, ConvertToBitReverse, 144 ConvertToGroupNonUniformBallot, ConvertToModule, 145 ConvertToSubgroupBallot>(context); 146 147 if (failed(applyPartialConversion(fn, *target, patterns))) 148 return signalPassFailure(); 149 } 150 151 ConvertToAtomCmpExchangeWeak::ConvertToAtomCmpExchangeWeak(MLIRContext *context) 152 : RewritePattern("test.convert_to_atomic_compare_exchange_weak_op", 153 {"spv.AtomicCompareExchangeWeak"}, 1, context) {} 154 155 LogicalResult 156 ConvertToAtomCmpExchangeWeak::matchAndRewrite(Operation *op, 157 PatternRewriter &rewriter) const { 158 Value ptr = op->getOperand(0); 159 Value value = op->getOperand(1); 160 Value comparator = op->getOperand(2); 161 162 // Create a spv.AtomicCompareExchangeWeak op with AtomicCounterMemory bits in 163 // memory semantics to additionally require AtomicStorage capability. 164 rewriter.replaceOpWithNewOp<spirv::AtomicCompareExchangeWeakOp>( 165 op, value.getType(), ptr, spirv::Scope::Workgroup, 166 spirv::MemorySemantics::AcquireRelease | 167 spirv::MemorySemantics::AtomicCounterMemory, 168 spirv::MemorySemantics::Acquire, value, comparator); 169 return success(); 170 } 171 172 ConvertToBitReverse::ConvertToBitReverse(MLIRContext *context) 173 : RewritePattern("test.convert_to_bit_reverse_op", {"spv.BitReverse"}, 1, 174 context) {} 175 176 LogicalResult 177 ConvertToBitReverse::matchAndRewrite(Operation *op, 178 PatternRewriter &rewriter) const { 179 Value predicate = op->getOperand(0); 180 181 rewriter.replaceOpWithNewOp<spirv::BitReverseOp>( 182 op, op->getResult(0).getType(), predicate); 183 return success(); 184 } 185 186 ConvertToGroupNonUniformBallot::ConvertToGroupNonUniformBallot( 187 MLIRContext *context) 188 : RewritePattern("test.convert_to_group_non_uniform_ballot_op", 189 {"spv.GroupNonUniformBallot"}, 1, context) {} 190 191 LogicalResult ConvertToGroupNonUniformBallot::matchAndRewrite( 192 Operation *op, PatternRewriter &rewriter) const { 193 Value predicate = op->getOperand(0); 194 195 rewriter.replaceOpWithNewOp<spirv::GroupNonUniformBallotOp>( 196 op, op->getResult(0).getType(), spirv::Scope::Workgroup, predicate); 197 return success(); 198 } 199 200 ConvertToModule::ConvertToModule(MLIRContext *context) 201 : RewritePattern("test.convert_to_module_op", {"spv.module"}, 1, context) {} 202 203 LogicalResult 204 ConvertToModule::matchAndRewrite(Operation *op, 205 PatternRewriter &rewriter) const { 206 rewriter.replaceOpWithNewOp<spirv::ModuleOp>( 207 op, spirv::AddressingModel::PhysicalStorageBuffer64, 208 spirv::MemoryModel::Vulkan); 209 return success(); 210 } 211 212 ConvertToSubgroupBallot::ConvertToSubgroupBallot(MLIRContext *context) 213 : RewritePattern("test.convert_to_subgroup_ballot_op", 214 {"spv.SubgroupBallotKHR"}, 1, context) {} 215 216 LogicalResult 217 ConvertToSubgroupBallot::matchAndRewrite(Operation *op, 218 PatternRewriter &rewriter) const { 219 Value predicate = op->getOperand(0); 220 221 rewriter.replaceOpWithNewOp<spirv::SubgroupBallotKHROp>( 222 op, op->getResult(0).getType(), predicate); 223 return success(); 224 } 225 226 namespace mlir { 227 void registerConvertToTargetEnvPass() { 228 PassRegistration<ConvertToTargetEnv> convertToTargetEnvPass( 229 "test-spirv-target-env", "Test SPIR-V target environment"); 230 } 231 } // namespace mlir 232