1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Math dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "../SPIRVCommon/Pattern.h" 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "mlir/Transforms/DialectConversion.h" 20 #include "llvm/Support/Debug.h" 21 22 #define DEBUG_TYPE "math-to-spirv-pattern" 23 24 using namespace mlir; 25 26 //===----------------------------------------------------------------------===// 27 // Utility functions 28 //===----------------------------------------------------------------------===// 29 30 /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the 31 /// given type is not a 32-bit scalar/vector type. 32 static Value getScalarOrVectorI32Constant(Type type, int value, 33 OpBuilder &builder, Location loc) { 34 if (auto vectorType = type.dyn_cast<VectorType>()) { 35 if (!vectorType.getElementType().isInteger(32)) 36 return nullptr; 37 SmallVector<int> values(vectorType.getNumElements(), value); 38 return builder.create<spirv::ConstantOp>(loc, type, 39 builder.getI32VectorAttr(values)); 40 } 41 if (type.isInteger(32)) 42 return builder.create<spirv::ConstantOp>(loc, type, 43 builder.getI32IntegerAttr(value)); 44 45 return nullptr; 46 } 47 48 //===----------------------------------------------------------------------===// 49 // Operation conversion 50 //===----------------------------------------------------------------------===// 51 52 // Note that DRR cannot be used for the patterns in this file: we may need to 53 // convert type along the way, which requires ConversionPattern. DRR generates 54 // normal RewritePattern. 55 56 namespace { 57 /// Converts math.copysign to SPIR-V ops. 58 class CopySignPattern final : public OpConversionPattern<math::CopySignOp> { 59 using OpConversionPattern::OpConversionPattern; 60 61 LogicalResult 62 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor, 63 ConversionPatternRewriter &rewriter) const override { 64 auto type = getTypeConverter()->convertType(copySignOp.getType()); 65 if (!type) 66 return failure(); 67 68 FloatType floatType; 69 if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) { 70 floatType = scalarType; 71 } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) { 72 floatType = vectorType.getElementType().cast<FloatType>(); 73 } else { 74 return failure(); 75 } 76 77 Location loc = copySignOp.getLoc(); 78 int bitwidth = floatType.getWidth(); 79 Type intType = rewriter.getIntegerType(bitwidth); 80 uint64_t intValue = uint64_t(1) << (bitwidth - 1); 81 82 Value signMask = rewriter.create<spirv::ConstantOp>( 83 loc, intType, rewriter.getIntegerAttr(intType, intValue)); 84 Value valueMask = rewriter.create<spirv::ConstantOp>( 85 loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u)); 86 87 if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) { 88 assert(vectorType.getRank() == 1); 89 int count = vectorType.getNumElements(); 90 intType = VectorType::get(count, intType); 91 92 SmallVector<Value> signSplat(count, signMask); 93 signMask = 94 rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat); 95 96 SmallVector<Value> valueSplat(count, valueMask); 97 valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType, 98 valueSplat); 99 } 100 101 Value lhsCast = 102 rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs()); 103 Value rhsCast = 104 rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs()); 105 106 Value value = rewriter.create<spirv::BitwiseAndOp>( 107 loc, intType, ValueRange{lhsCast, valueMask}); 108 Value sign = rewriter.create<spirv::BitwiseAndOp>( 109 loc, intType, ValueRange{rhsCast, signMask}); 110 111 Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType, 112 ValueRange{value, sign}); 113 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result); 114 return success(); 115 } 116 }; 117 118 /// Converts math.ctlz to SPIR-V ops. 119 /// 120 /// SPIR-V does not have a direct operations for counting leading zeros. If 121 /// Shader capability is supported, we can leverage GLSL FindUMsb to calculate 122 /// it. 123 class CountLeadingZerosPattern final 124 : public OpConversionPattern<math::CountLeadingZerosOp> { 125 using OpConversionPattern::OpConversionPattern; 126 127 LogicalResult 128 matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor, 129 ConversionPatternRewriter &rewriter) const override { 130 auto type = getTypeConverter()->convertType(countOp.getType()); 131 if (!type) 132 return failure(); 133 134 // We can only support 32-bit integer types for now. 135 unsigned bitwidth = 0; 136 if (type.isa<IntegerType>()) 137 bitwidth = type.getIntOrFloatBitWidth(); 138 if (auto vectorType = type.dyn_cast<VectorType>()) 139 bitwidth = vectorType.getElementTypeBitWidth(); 140 if (bitwidth != 32) 141 return failure(); 142 143 Location loc = countOp.getLoc(); 144 Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc); 145 Value msb = 146 rewriter.create<spirv::GLSLFindUMsbOp>(loc, adaptor.getOperand()); 147 // We need to subtract from 31 given that the index is from the least 148 // significant bit. 149 rewriter.replaceOpWithNewOp<spirv::ISubOp>(countOp, val31, msb); 150 return success(); 151 } 152 }; 153 154 /// Converts math.expm1 to SPIR-V ops. 155 /// 156 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to 157 /// these operations. 158 template <typename ExpOp> 159 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> { 160 using OpConversionPattern::OpConversionPattern; 161 162 LogicalResult 163 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor, 164 ConversionPatternRewriter &rewriter) const override { 165 assert(adaptor.getOperands().size() == 1); 166 Location loc = operation.getLoc(); 167 auto type = this->getTypeConverter()->convertType(operation.getType()); 168 auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand()); 169 auto one = spirv::ConstantOp::getOne(type, loc, rewriter); 170 rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one); 171 return success(); 172 } 173 }; 174 175 /// Converts math.log1p to SPIR-V ops. 176 /// 177 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to 178 /// these operations. 179 template <typename LogOp> 180 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { 181 using OpConversionPattern::OpConversionPattern; 182 183 LogicalResult 184 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, 185 ConversionPatternRewriter &rewriter) const override { 186 assert(adaptor.getOperands().size() == 1); 187 Location loc = operation.getLoc(); 188 auto type = this->getTypeConverter()->convertType(operation.getType()); 189 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); 190 auto onePlus = 191 rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand()); 192 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus); 193 return success(); 194 } 195 }; 196 } // namespace 197 198 //===----------------------------------------------------------------------===// 199 // Pattern population 200 //===----------------------------------------------------------------------===// 201 202 namespace mlir { 203 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 204 RewritePatternSet &patterns) { 205 // Core patterns 206 patterns.add<CopySignPattern>(typeConverter, patterns.getContext()); 207 208 // GLSL patterns 209 patterns 210 .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>, 211 ExpM1OpPattern<spirv::GLSLExpOp>, 212 spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>, 213 spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>, 214 spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>, 215 spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>, 216 spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>, 217 spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>, 218 spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>, 219 spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>, 220 spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>, 221 spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>, 222 spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>, 223 spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>( 224 typeConverter, patterns.getContext()); 225 226 // OpenCL patterns 227 patterns.add<Log1pOpPattern<spirv::OCLLogOp>, ExpM1OpPattern<spirv::OCLExpOp>, 228 spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>, 229 spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>, 230 spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>, 231 spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>, 232 spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>, 233 spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>, 234 spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>, 235 spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>, 236 spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>, 237 spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>, 238 spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>, 239 spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>, 240 spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>( 241 typeConverter, patterns.getContext()); 242 } 243 244 } // namespace mlir 245