1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Math dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "../SPIRVCommon/Pattern.h" 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "llvm/Support/Debug.h" 21 22 #define DEBUG_TYPE "math-to-spirv-pattern" 23 24 using namespace mlir; 25 26 //===----------------------------------------------------------------------===// 27 // Utility functions 28 //===----------------------------------------------------------------------===// 29 30 /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the 31 /// given type is not a 32-bit scalar/vector type. 32 static Value getScalarOrVectorI32Constant(Type type, int value, 33 OpBuilder &builder, Location loc) { 34 if (auto vectorType = type.dyn_cast<VectorType>()) { 35 if (!vectorType.getElementType().isInteger(32)) 36 return nullptr; 37 SmallVector<int> values(vectorType.getNumElements(), value); 38 return builder.create<spirv::ConstantOp>(loc, type, 39 builder.getI32VectorAttr(values)); 40 } 41 if (type.isInteger(32)) 42 return builder.create<spirv::ConstantOp>(loc, type, 43 builder.getI32IntegerAttr(value)); 44 45 return nullptr; 46 } 47 48 //===----------------------------------------------------------------------===// 49 // Operation conversion 50 //===----------------------------------------------------------------------===// 51 52 // Note that DRR cannot be used for the patterns in this file: we may need to 53 // convert type along the way, which requires ConversionPattern. DRR generates 54 // normal RewritePattern. 55 56 namespace { 57 /// Converts math.copysign to SPIR-V ops. 58 class CopySignPattern final : public OpConversionPattern<math::CopySignOp> { 59 using OpConversionPattern::OpConversionPattern; 60 61 LogicalResult 62 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor, 63 ConversionPatternRewriter &rewriter) const override { 64 auto type = getTypeConverter()->convertType(copySignOp.getType()); 65 if (!type) 66 return failure(); 67 68 FloatType floatType; 69 if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) { 70 floatType = scalarType; 71 } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) { 72 floatType = vectorType.getElementType().cast<FloatType>(); 73 } else { 74 return failure(); 75 } 76 77 Location loc = copySignOp.getLoc(); 78 int bitwidth = floatType.getWidth(); 79 Type intType = rewriter.getIntegerType(bitwidth); 80 uint64_t intValue = uint64_t(1) << (bitwidth - 1); 81 82 Value signMask = rewriter.create<spirv::ConstantOp>( 83 loc, intType, rewriter.getIntegerAttr(intType, intValue)); 84 Value valueMask = rewriter.create<spirv::ConstantOp>( 85 loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u)); 86 87 if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) { 88 assert(vectorType.getRank() == 1); 89 int count = vectorType.getNumElements(); 90 intType = VectorType::get(count, intType); 91 92 SmallVector<Value> signSplat(count, signMask); 93 signMask = 94 rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat); 95 96 SmallVector<Value> valueSplat(count, valueMask); 97 valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType, 98 valueSplat); 99 } 100 101 Value lhsCast = 102 rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs()); 103 Value rhsCast = 104 rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs()); 105 106 Value value = rewriter.create<spirv::BitwiseAndOp>( 107 loc, intType, ValueRange{lhsCast, valueMask}); 108 Value sign = rewriter.create<spirv::BitwiseAndOp>( 109 loc, intType, ValueRange{rhsCast, signMask}); 110 111 Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType, 112 ValueRange{value, sign}); 113 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result); 114 return success(); 115 } 116 }; 117 118 /// Converts math.ctlz to SPIR-V ops. 119 /// 120 /// SPIR-V does not have a direct operations for counting leading zeros. If 121 /// Shader capability is supported, we can leverage GLSL FindUMsb to calculate 122 /// it. 123 class CountLeadingZerosPattern final 124 : public OpConversionPattern<math::CountLeadingZerosOp> { 125 using OpConversionPattern::OpConversionPattern; 126 127 LogicalResult 128 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor, 129 ConversionPatternRewriter &rewriter) const override { 130 auto type = getTypeConverter()->convertType(countOp.getType()); 131 if (!type) 132 return failure(); 133 134 // We can only support 32-bit integer types for now. 135 unsigned bitwidth = 0; 136 if (type.isa<IntegerType>()) 137 bitwidth = type.getIntOrFloatBitWidth(); 138 if (auto vectorType = type.dyn_cast<VectorType>()) 139 bitwidth = vectorType.getElementTypeBitWidth(); 140 if (bitwidth != 32) 141 return failure(); 142 143 Location loc = countOp.getLoc(); 144 Value input = adaptor.getOperand(); 145 Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc); 146 Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc); 147 Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc); 148 149 Value msb = rewriter.create<spirv::GLSLFindUMsbOp>(loc, input); 150 // We need to subtract from 31 given that the index returned by GLSL 151 // FindUMsb is counted from the least significant bit. Theoretically this 152 // also gives the correct result even if the integer has all zero bits, in 153 // which case GLSL FindUMsb would return -1. 154 Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb); 155 // However, certain Vulkan implementations have driver bugs for the corner 156 // case where the input is zero. And.. it can be smart to optimize a select 157 // only involving the corner case. So separately compute the result when the 158 // input is either zero or one. 159 Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input); 160 Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1); 161 rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput, 162 subMsb); 163 return success(); 164 } 165 }; 166 167 /// Converts math.expm1 to SPIR-V ops. 168 /// 169 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to 170 /// these operations. 171 template <typename ExpOp> 172 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> { 173 using OpConversionPattern::OpConversionPattern; 174 175 LogicalResult 176 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor, 177 ConversionPatternRewriter &rewriter) const override { 178 assert(adaptor.getOperands().size() == 1); 179 Location loc = operation.getLoc(); 180 auto type = this->getTypeConverter()->convertType(operation.getType()); 181 auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand()); 182 auto one = spirv::ConstantOp::getOne(type, loc, rewriter); 183 rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one); 184 return success(); 185 } 186 }; 187 188 /// Converts math.log1p to SPIR-V ops. 189 /// 190 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to 191 /// these operations. 192 template <typename LogOp> 193 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { 194 using OpConversionPattern::OpConversionPattern; 195 196 LogicalResult 197 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, 198 ConversionPatternRewriter &rewriter) const override { 199 assert(adaptor.getOperands().size() == 1); 200 Location loc = operation.getLoc(); 201 auto type = this->getTypeConverter()->convertType(operation.getType()); 202 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); 203 auto onePlus = 204 rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand()); 205 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus); 206 return success(); 207 } 208 }; 209 210 /// Converts math.powf to SPIRV-Ops. 211 struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { 212 using OpConversionPattern::OpConversionPattern; 213 214 LogicalResult 215 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor, 216 ConversionPatternRewriter &rewriter) const override { 217 auto dstType = getTypeConverter()->convertType(powfOp.getType()); 218 if (!dstType) 219 return failure(); 220 221 // Per GLSL Pow extended instruction spec: 222 // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0." 223 Location loc = powfOp.getLoc(); 224 Value zero = 225 spirv::ConstantOp::getZero(adaptor.getLhs().getType(), loc, rewriter); 226 Value lessThan = 227 rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero); 228 Value abs = rewriter.create<spirv::GLSLFAbsOp>(loc, adaptor.getLhs()); 229 Value pow = rewriter.create<spirv::GLSLPowOp>(loc, abs, adaptor.getRhs()); 230 Value negate = rewriter.create<spirv::FNegateOp>(loc, pow); 231 rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, lessThan, negate, pow); 232 return success(); 233 } 234 }; 235 236 } // namespace 237 238 //===----------------------------------------------------------------------===// 239 // Pattern population 240 //===----------------------------------------------------------------------===// 241 242 namespace mlir { 243 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 244 RewritePatternSet &patterns) { 245 // Core patterns 246 patterns.add<CopySignPattern>(typeConverter, patterns.getContext()); 247 248 // GLSL patterns 249 patterns 250 .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>, 251 ExpM1OpPattern<spirv::GLSLExpOp>, PowFOpPattern, 252 spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>, 253 spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>, 254 spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>, 255 spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>, 256 spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>, 257 spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>, 258 spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>, 259 spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>, 260 spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>, 261 spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>, 262 spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>( 263 typeConverter, patterns.getContext()); 264 265 // OpenCL patterns 266 patterns.add<Log1pOpPattern<spirv::OCLLogOp>, ExpM1OpPattern<spirv::OCLExpOp>, 267 spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>, 268 spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>, 269 spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>, 270 spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>, 271 spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>, 272 spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>, 273 spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>, 274 spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>, 275 spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>, 276 spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>, 277 spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>, 278 spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>, 279 spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>( 280 typeConverter, patterns.getContext()); 281 } 282 283 } // namespace mlir 284