1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Math dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "../SPIRVCommon/Pattern.h" 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/IR/TypeUtilities.h" 20 #include "mlir/Transforms/DialectConversion.h" 21 #include "llvm/Support/Debug.h" 22 23 #define DEBUG_TYPE "math-to-spirv-pattern" 24 25 using namespace mlir; 26 27 //===----------------------------------------------------------------------===// 28 // Utility functions 29 //===----------------------------------------------------------------------===// 30 31 /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the 32 /// given type is not a 32-bit scalar/vector type. 33 static Value getScalarOrVectorI32Constant(Type type, int value, 34 OpBuilder &builder, Location loc) { 35 if (auto vectorType = type.dyn_cast<VectorType>()) { 36 if (!vectorType.getElementType().isInteger(32)) 37 return nullptr; 38 SmallVector<int> values(vectorType.getNumElements(), value); 39 return builder.create<spirv::ConstantOp>(loc, type, 40 builder.getI32VectorAttr(values)); 41 } 42 if (type.isInteger(32)) 43 return builder.create<spirv::ConstantOp>(loc, type, 44 builder.getI32IntegerAttr(value)); 45 46 return nullptr; 47 } 48 49 //===----------------------------------------------------------------------===// 50 // Operation conversion 51 //===----------------------------------------------------------------------===// 52 53 // Note that DRR cannot be used for the patterns in this file: we may need to 54 // convert type along the way, which requires ConversionPattern. DRR generates 55 // normal RewritePattern. 56 57 namespace { 58 /// Converts math.copysign to SPIR-V ops. 59 class CopySignPattern final : public OpConversionPattern<math::CopySignOp> { 60 using OpConversionPattern::OpConversionPattern; 61 62 LogicalResult 63 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor, 64 ConversionPatternRewriter &rewriter) const override { 65 auto type = getTypeConverter()->convertType(copySignOp.getType()); 66 if (!type) 67 return failure(); 68 69 FloatType floatType; 70 if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) { 71 floatType = scalarType; 72 } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) { 73 floatType = vectorType.getElementType().cast<FloatType>(); 74 } else { 75 return failure(); 76 } 77 78 Location loc = copySignOp.getLoc(); 79 int bitwidth = floatType.getWidth(); 80 Type intType = rewriter.getIntegerType(bitwidth); 81 uint64_t intValue = uint64_t(1) << (bitwidth - 1); 82 83 Value signMask = rewriter.create<spirv::ConstantOp>( 84 loc, intType, rewriter.getIntegerAttr(intType, intValue)); 85 Value valueMask = rewriter.create<spirv::ConstantOp>( 86 loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u)); 87 88 if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) { 89 assert(vectorType.getRank() == 1); 90 int count = vectorType.getNumElements(); 91 intType = VectorType::get(count, intType); 92 93 SmallVector<Value> signSplat(count, signMask); 94 signMask = 95 rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat); 96 97 SmallVector<Value> valueSplat(count, valueMask); 98 valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType, 99 valueSplat); 100 } 101 102 Value lhsCast = 103 rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs()); 104 Value rhsCast = 105 rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs()); 106 107 Value value = rewriter.create<spirv::BitwiseAndOp>( 108 loc, intType, ValueRange{lhsCast, valueMask}); 109 Value sign = rewriter.create<spirv::BitwiseAndOp>( 110 loc, intType, ValueRange{rhsCast, signMask}); 111 112 Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType, 113 ValueRange{value, sign}); 114 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result); 115 return success(); 116 } 117 }; 118 119 /// Converts math.ctlz to SPIR-V ops. 120 /// 121 /// SPIR-V does not have a direct operations for counting leading zeros. If 122 /// Shader capability is supported, we can leverage GLSL FindUMsb to calculate 123 /// it. 124 class CountLeadingZerosPattern final 125 : public OpConversionPattern<math::CountLeadingZerosOp> { 126 using OpConversionPattern::OpConversionPattern; 127 128 LogicalResult 129 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor, 130 ConversionPatternRewriter &rewriter) const override { 131 auto type = getTypeConverter()->convertType(countOp.getType()); 132 if (!type) 133 return failure(); 134 135 // We can only support 32-bit integer types for now. 136 unsigned bitwidth = 0; 137 if (type.isa<IntegerType>()) 138 bitwidth = type.getIntOrFloatBitWidth(); 139 if (auto vectorType = type.dyn_cast<VectorType>()) 140 bitwidth = vectorType.getElementTypeBitWidth(); 141 if (bitwidth != 32) 142 return failure(); 143 144 Location loc = countOp.getLoc(); 145 Value input = adaptor.getOperand(); 146 Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc); 147 Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc); 148 Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc); 149 150 Value msb = rewriter.create<spirv::GLSLFindUMsbOp>(loc, input); 151 // We need to subtract from 31 given that the index returned by GLSL 152 // FindUMsb is counted from the least significant bit. Theoretically this 153 // also gives the correct result even if the integer has all zero bits, in 154 // which case GLSL FindUMsb would return -1. 155 Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb); 156 // However, certain Vulkan implementations have driver bugs for the corner 157 // case where the input is zero. And.. it can be smart to optimize a select 158 // only involving the corner case. So separately compute the result when the 159 // input is either zero or one. 160 Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input); 161 Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1); 162 rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput, 163 subMsb); 164 return success(); 165 } 166 }; 167 168 /// Converts math.expm1 to SPIR-V ops. 169 /// 170 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to 171 /// these operations. 172 template <typename ExpOp> 173 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> { 174 using OpConversionPattern::OpConversionPattern; 175 176 LogicalResult 177 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor, 178 ConversionPatternRewriter &rewriter) const override { 179 assert(adaptor.getOperands().size() == 1); 180 Location loc = operation.getLoc(); 181 auto type = this->getTypeConverter()->convertType(operation.getType()); 182 auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand()); 183 auto one = spirv::ConstantOp::getOne(type, loc, rewriter); 184 rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one); 185 return success(); 186 } 187 }; 188 189 /// Converts math.log1p to SPIR-V ops. 190 /// 191 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to 192 /// these operations. 193 template <typename LogOp> 194 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { 195 using OpConversionPattern::OpConversionPattern; 196 197 LogicalResult 198 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, 199 ConversionPatternRewriter &rewriter) const override { 200 assert(adaptor.getOperands().size() == 1); 201 Location loc = operation.getLoc(); 202 auto type = this->getTypeConverter()->convertType(operation.getType()); 203 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); 204 auto onePlus = 205 rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand()); 206 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus); 207 return success(); 208 } 209 }; 210 211 /// Converts math.powf to SPIRV-Ops. 212 struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> { 213 using OpConversionPattern::OpConversionPattern; 214 215 LogicalResult 216 matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor, 217 ConversionPatternRewriter &rewriter) const override { 218 auto dstType = getTypeConverter()->convertType(powfOp.getType()); 219 if (!dstType) 220 return failure(); 221 222 // Per GLSL Pow extended instruction spec: 223 // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0." 224 Location loc = powfOp.getLoc(); 225 Value zero = 226 spirv::ConstantOp::getZero(adaptor.getLhs().getType(), loc, rewriter); 227 Value lessThan = 228 rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero); 229 Value abs = rewriter.create<spirv::GLSLFAbsOp>(loc, adaptor.getLhs()); 230 Value pow = rewriter.create<spirv::GLSLPowOp>(loc, abs, adaptor.getRhs()); 231 Value negate = rewriter.create<spirv::FNegateOp>(loc, pow); 232 rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, lessThan, negate, pow); 233 return success(); 234 } 235 }; 236 237 /// Converts math.round to GLSL SPIRV extended ops. 238 struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> { 239 using OpConversionPattern::OpConversionPattern; 240 241 LogicalResult 242 matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor, 243 ConversionPatternRewriter &rewriter) const override { 244 Location loc = roundOp.getLoc(); 245 auto operand = roundOp.getOperand(); 246 auto ty = operand.getType(); 247 auto ety = getElementTypeOrSelf(ty); 248 249 auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter); 250 auto one = spirv::ConstantOp::getOne(ty, loc, rewriter); 251 Value half; 252 if (VectorType vty = ty.dyn_cast<VectorType>()) { 253 half = rewriter.create<spirv::ConstantOp>( 254 loc, vty, 255 DenseElementsAttr::get(vty, 256 rewriter.getFloatAttr(ety, 0.5).getValue())); 257 } else { 258 half = rewriter.create<spirv::ConstantOp>( 259 loc, ty, rewriter.getFloatAttr(ety, 0.5)); 260 } 261 262 auto abs = rewriter.create<spirv::GLSLFAbsOp>(loc, operand); 263 auto floor = rewriter.create<spirv::GLSLFloorOp>(loc, abs); 264 auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor); 265 auto greater = 266 rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half); 267 auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero); 268 auto add = rewriter.create<spirv::FAddOp>(loc, floor, select); 269 rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand); 270 return success(); 271 } 272 }; 273 274 } // namespace 275 276 //===----------------------------------------------------------------------===// 277 // Pattern population 278 //===----------------------------------------------------------------------===// 279 280 namespace mlir { 281 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 282 RewritePatternSet &patterns) { 283 // Core patterns 284 patterns.add<CopySignPattern>(typeConverter, patterns.getContext()); 285 286 // GLSL patterns 287 patterns 288 .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>, 289 ExpM1OpPattern<spirv::GLSLExpOp>, PowFOpPattern, RoundOpPattern, 290 spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>, 291 spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>, 292 spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>, 293 spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>, 294 spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>, 295 spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>, 296 spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>, 297 spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>, 298 spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>, 299 spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>, 300 spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>( 301 typeConverter, patterns.getContext()); 302 303 // OpenCL patterns 304 patterns.add<Log1pOpPattern<spirv::OCLLogOp>, ExpM1OpPattern<spirv::OCLExpOp>, 305 spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>, 306 spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>, 307 spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>, 308 spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>, 309 spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>, 310 spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>, 311 spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>, 312 spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>, 313 spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>, 314 spirv::ElementwiseOpPattern<math::RoundOp, spirv::OCLRoundOp>, 315 spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>, 316 spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>, 317 spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>, 318 spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>( 319 typeConverter, patterns.getContext()); 320 } 321 322 } // namespace mlir 323