1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Math dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "../SPIRVCommon/Pattern.h" 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "mlir/IR/BuiltinTypes.h" 19 #include "llvm/Support/Debug.h" 20 21 #define DEBUG_TYPE "math-to-spirv-pattern" 22 23 using namespace mlir; 24 25 //===----------------------------------------------------------------------===// 26 // Operation conversion 27 //===----------------------------------------------------------------------===// 28 29 // Note that DRR cannot be used for the patterns in this file: we may need to 30 // convert type along the way, which requires ConversionPattern. DRR generates 31 // normal RewritePattern. 32 33 namespace { 34 /// Converts math.copysign to SPIR-V ops. 35 class CopySignPattern final : public OpConversionPattern<math::CopySignOp> { 36 using OpConversionPattern::OpConversionPattern; 37 38 LogicalResult 39 matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor, 40 ConversionPatternRewriter &rewriter) const override { 41 auto type = getTypeConverter()->convertType(copySignOp.getType()); 42 if (!type) 43 return failure(); 44 45 FloatType floatType; 46 if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) { 47 floatType = scalarType; 48 } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) { 49 floatType = vectorType.getElementType().cast<FloatType>(); 50 } else { 51 return failure(); 52 } 53 54 Location loc = copySignOp.getLoc(); 55 int bitwidth = floatType.getWidth(); 56 Type intType = rewriter.getIntegerType(bitwidth); 57 58 Value signMask = rewriter.create<spirv::ConstantOp>( 59 loc, intType, rewriter.getIntegerAttr(intType, (1u << (bitwidth - 1)))); 60 Value valueMask = rewriter.create<spirv::ConstantOp>( 61 loc, intType, 62 rewriter.getIntegerAttr(intType, (1u << (bitwidth - 1)) - 1u)); 63 64 if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) { 65 assert(vectorType.getRank() == 1); 66 int count = vectorType.getNumElements(); 67 intType = VectorType::get(count, intType); 68 69 SmallVector<Value> signSplat(count, signMask); 70 signMask = 71 rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat); 72 73 SmallVector<Value> valueSplat(count, valueMask); 74 valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType, 75 valueSplat); 76 } 77 78 Value lhsCast = 79 rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs()); 80 Value rhsCast = 81 rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs()); 82 83 Value value = rewriter.create<spirv::BitwiseAndOp>( 84 loc, intType, ValueRange{lhsCast, valueMask}); 85 Value sign = rewriter.create<spirv::BitwiseAndOp>( 86 loc, intType, ValueRange{rhsCast, signMask}); 87 88 Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType, 89 ValueRange{value, sign}); 90 rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result); 91 return success(); 92 } 93 }; 94 95 /// Converts math.expm1 to SPIR-V ops. 96 /// 97 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to 98 /// these operations. 99 template <typename ExpOp> 100 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> { 101 using OpConversionPattern::OpConversionPattern; 102 103 LogicalResult 104 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor, 105 ConversionPatternRewriter &rewriter) const override { 106 assert(adaptor.getOperands().size() == 1); 107 Location loc = operation.getLoc(); 108 auto type = this->getTypeConverter()->convertType(operation.getType()); 109 auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand()); 110 auto one = spirv::ConstantOp::getOne(type, loc, rewriter); 111 rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one); 112 return success(); 113 } 114 }; 115 116 /// Converts math.log1p to SPIR-V ops. 117 /// 118 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to 119 /// these operations. 120 template <typename LogOp> 121 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { 122 using OpConversionPattern::OpConversionPattern; 123 124 LogicalResult 125 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, 126 ConversionPatternRewriter &rewriter) const override { 127 assert(adaptor.getOperands().size() == 1); 128 Location loc = operation.getLoc(); 129 auto type = this->getTypeConverter()->convertType(operation.getType()); 130 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); 131 auto onePlus = 132 rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand()); 133 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus); 134 return success(); 135 } 136 }; 137 } // namespace 138 139 //===----------------------------------------------------------------------===// 140 // Pattern population 141 //===----------------------------------------------------------------------===// 142 143 namespace mlir { 144 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 145 RewritePatternSet &patterns) { 146 // Core patterns 147 patterns.add<CopySignPattern>(typeConverter, patterns.getContext()); 148 149 // GLSL patterns 150 patterns 151 .add<Log1pOpPattern<spirv::GLSLLogOp>, ExpM1OpPattern<spirv::GLSLExpOp>, 152 spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>, 153 spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>, 154 spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>, 155 spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>, 156 spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>, 157 spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>, 158 spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>, 159 spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>, 160 spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>, 161 spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>, 162 spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>, 163 spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>( 164 typeConverter, patterns.getContext()); 165 166 // OpenCL patterns 167 patterns.add<Log1pOpPattern<spirv::OCLLogOp>, ExpM1OpPattern<spirv::OCLExpOp>, 168 spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>, 169 spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>, 170 spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>, 171 spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>, 172 spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>, 173 spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>, 174 spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>, 175 spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>, 176 spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>, 177 spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>, 178 spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>, 179 spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>, 180 spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>( 181 typeConverter, patterns.getContext()); 182 } 183 184 } // namespace mlir 185