1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Math dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "../SPIRVCommon/Pattern.h" 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "llvm/Support/Debug.h" 21 22 #define DEBUG_TYPE "math-to-spirv-pattern" 23 24 using namespace mlir; 25 26 //===----------------------------------------------------------------------===// 27 // Utility functions 28 //===----------------------------------------------------------------------===// 29 30 /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the 31 /// given type is not a 32-bit scalar/vector type. 32 static Value getScalarOrVectorI32Constant(Type type, int value, 33 OpBuilder &builder, Location loc) { 34 if (auto vectorType = type.dyn_cast<VectorType>()) { 35 if (!vectorType.getElementType().isInteger(32)) 36 return nullptr; 37 SmallVector<int> values(vectorType.getNumElements(), value); 38 return builder.create<spirv::ConstantOp>(loc, type, 39 builder.getI32VectorAttr(values)); 40 } 41 if (type.isInteger(32)) 42 return builder.create<spirv::ConstantOp>(loc, type, 43 builder.getI32IntegerAttr(value)); 44 45 return nullptr; 46 } 47 48 //===----------------------------------------------------------------------===// 49 // Operation conversion 50 //===----------------------------------------------------------------------===// 51 52 // Note that DRR cannot be used for the patterns in this file: we may need to 53 // convert type along the way, which requires ConversionPattern. DRR generates 54 // normal RewritePattern. 55 56 namespace { 57 /// Converts math.copysign to SPIR-V ops. 58 class CopySignPattern final : public OpConversionPattern<math::CopySignOp> { 59 using OpConversionPattern::OpConversionPattern; 60 61 LogicalResult 62 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor, 63 ConversionPatternRewriter &rewriter) const override { 64 auto type = getTypeConverter()->convertType(copySignOp.getType()); 65 if (!type) 66 return failure(); 67 68 FloatType floatType; 69 if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) { 70 floatType = scalarType; 71 } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) { 72 floatType = vectorType.getElementType().cast<FloatType>(); 73 } else { 74 return failure(); 75 } 76 77 Location loc = copySignOp.getLoc(); 78 int bitwidth = floatType.getWidth(); 79 Type intType = rewriter.getIntegerType(bitwidth); 80 uint64_t intValue = uint64_t(1) << (bitwidth - 1); 81 82 Value signMask = rewriter.create<spirv::ConstantOp>( 83 loc, intType, rewriter.getIntegerAttr(intType, intValue)); 84 Value valueMask = rewriter.create<spirv::ConstantOp>( 85 loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u)); 86 87 if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) { 88 assert(vectorType.getRank() == 1); 89 int count = vectorType.getNumElements(); 90 intType = VectorType::get(count, intType); 91 92 SmallVector<Value> signSplat(count, signMask); 93 signMask = 94 rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat); 95 96 SmallVector<Value> valueSplat(count, valueMask); 97 valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType, 98 valueSplat); 99 } 100 101 Value lhsCast = 102 rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs()); 103 Value rhsCast = 104 rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs()); 105 106 Value value = rewriter.create<spirv::BitwiseAndOp>( 107 loc, intType, ValueRange{lhsCast, valueMask}); 108 Value sign = rewriter.create<spirv::BitwiseAndOp>( 109 loc, intType, ValueRange{rhsCast, signMask}); 110 111 Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType, 112 ValueRange{value, sign}); 113 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result); 114 return success(); 115 } 116 }; 117 118 /// Converts math.ctlz to SPIR-V ops. 119 /// 120 /// SPIR-V does not have a direct operations for counting leading zeros. If 121 /// Shader capability is supported, we can leverage GLSL FindUMsb to calculate 122 /// it. 123 class CountLeadingZerosPattern final 124 : public OpConversionPattern<math::CountLeadingZerosOp> { 125 using OpConversionPattern::OpConversionPattern; 126 127 LogicalResult 128 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor, 129 ConversionPatternRewriter &rewriter) const override { 130 auto type = getTypeConverter()->convertType(countOp.getType()); 131 if (!type) 132 return failure(); 133 134 // We can only support 32-bit integer types for now. 135 unsigned bitwidth = 0; 136 if (type.isa<IntegerType>()) 137 bitwidth = type.getIntOrFloatBitWidth(); 138 if (auto vectorType = type.dyn_cast<VectorType>()) 139 bitwidth = vectorType.getElementTypeBitWidth(); 140 if (bitwidth != 32) 141 return failure(); 142 143 Location loc = countOp.getLoc(); 144 Value allOneBits = getScalarOrVectorI32Constant(type, -1, rewriter, loc); 145 Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc); 146 Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc); 147 Value msb = 148 rewriter.create<spirv::GLSLFindUMsbOp>(loc, adaptor.getOperand()); 149 // We need to subtract from 31 given that the index is from the least 150 // significant bit. 151 Value sub = rewriter.create<spirv::ISubOp>(loc, val31, msb); 152 // If the integer has all zero bits, GLSL FindUMsb would return -1. So 153 // theoretically (31 - FindUMsb) should still give the correct result. 154 // However, certain Vulkan implementations have driver bugs regarding it. 155 // So handle the corner case explicity to workaround it. 156 Value cmp = rewriter.create<spirv::IEqualOp>(loc, msb, allOneBits); 157 rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, val32, sub); 158 return success(); 159 } 160 }; 161 162 /// Converts math.expm1 to SPIR-V ops. 163 /// 164 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to 165 /// these operations. 166 template <typename ExpOp> 167 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> { 168 using OpConversionPattern::OpConversionPattern; 169 170 LogicalResult 171 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor, 172 ConversionPatternRewriter &rewriter) const override { 173 assert(adaptor.getOperands().size() == 1); 174 Location loc = operation.getLoc(); 175 auto type = this->getTypeConverter()->convertType(operation.getType()); 176 auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand()); 177 auto one = spirv::ConstantOp::getOne(type, loc, rewriter); 178 rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one); 179 return success(); 180 } 181 }; 182 183 /// Converts math.log1p to SPIR-V ops. 184 /// 185 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to 186 /// these operations. 187 template <typename LogOp> 188 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { 189 using OpConversionPattern::OpConversionPattern; 190 191 LogicalResult 192 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, 193 ConversionPatternRewriter &rewriter) const override { 194 assert(adaptor.getOperands().size() == 1); 195 Location loc = operation.getLoc(); 196 auto type = this->getTypeConverter()->convertType(operation.getType()); 197 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); 198 auto onePlus = 199 rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand()); 200 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus); 201 return success(); 202 } 203 }; 204 205 /// Converts math.powf to SPIRV-Ops. 206 struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { 207 using OpConversionPattern::OpConversionPattern; 208 209 LogicalResult 210 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor, 211 ConversionPatternRewriter &rewriter) const override { 212 auto dstType = getTypeConverter()->convertType(powfOp.getType()); 213 if (!dstType) 214 return failure(); 215 216 // Per GLSL Pow extended instruction spec: 217 // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0." 218 Location loc = powfOp.getLoc(); 219 Value zero = 220 spirv::ConstantOp::getZero(adaptor.getLhs().getType(), loc, rewriter); 221 Value lessThan = 222 rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero); 223 Value abs = rewriter.create<spirv::GLSLFAbsOp>(loc, adaptor.getLhs()); 224 Value pow = rewriter.create<spirv::GLSLPowOp>(loc, abs, adaptor.getRhs()); 225 Value negate = rewriter.create<spirv::FNegateOp>(loc, pow); 226 rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, lessThan, negate, pow); 227 return success(); 228 } 229 }; 230 231 } // namespace 232 233 //===----------------------------------------------------------------------===// 234 // Pattern population 235 //===----------------------------------------------------------------------===// 236 237 namespace mlir { 238 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 239 RewritePatternSet &patterns) { 240 // Core patterns 241 patterns.add<CopySignPattern>(typeConverter, patterns.getContext()); 242 243 // GLSL patterns 244 patterns 245 .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>, 246 ExpM1OpPattern<spirv::GLSLExpOp>, PowFOpPattern, 247 spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>, 248 spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>, 249 spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>, 250 spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>, 251 spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>, 252 spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>, 253 spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>, 254 spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>, 255 spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>, 256 spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>, 257 spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>( 258 typeConverter, patterns.getContext()); 259 260 // OpenCL patterns 261 patterns.add<Log1pOpPattern<spirv::OCLLogOp>, ExpM1OpPattern<spirv::OCLExpOp>, 262 spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>, 263 spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>, 264 spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>, 265 spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>, 266 spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>, 267 spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>, 268 spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>, 269 spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>, 270 spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>, 271 spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>, 272 spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>, 273 spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>, 274 spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>( 275 typeConverter, patterns.getContext()); 276 } 277 278 } // namespace mlir 279