1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Math dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "../SPIRVCommon/Pattern.h" 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "llvm/Support/Debug.h" 19 20 #define DEBUG_TYPE "math-to-spirv-pattern" 21 22 using namespace mlir; 23 24 //===----------------------------------------------------------------------===// 25 // Operation conversion 26 //===----------------------------------------------------------------------===// 27 28 // Note that DRR cannot be used for the patterns in this file: we may need to 29 // convert type along the way, which requires ConversionPattern. DRR generates 30 // normal RewritePattern. 31 32 namespace { 33 /// Converts math.expm1 to SPIR-V ops. 34 /// 35 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to 36 /// these operations. 37 template <typename ExpOp> 38 class ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> { 39 public: 40 using OpConversionPattern<math::ExpM1Op>::OpConversionPattern; 41 42 LogicalResult 43 matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor, 44 ConversionPatternRewriter &rewriter) const override { 45 assert(adaptor.getOperands().size() == 1); 46 Location loc = operation.getLoc(); 47 auto type = this->getTypeConverter()->convertType(operation.getType()); 48 auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand()); 49 auto one = spirv::ConstantOp::getOne(type, loc, rewriter); 50 rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one); 51 return success(); 52 } 53 }; 54 55 /// Converts math.log1p to SPIR-V ops. 56 /// 57 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to 58 /// these operations. 59 template <typename LogOp> 60 class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { 61 public: 62 using OpConversionPattern<math::Log1pOp>::OpConversionPattern; 63 64 LogicalResult 65 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, 66 ConversionPatternRewriter &rewriter) const override { 67 assert(adaptor.getOperands().size() == 1); 68 Location loc = operation.getLoc(); 69 auto type = this->getTypeConverter()->convertType(operation.getType()); 70 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); 71 auto onePlus = 72 rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand()); 73 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus); 74 return success(); 75 } 76 }; 77 } // namespace 78 79 //===----------------------------------------------------------------------===// 80 // Pattern population 81 //===----------------------------------------------------------------------===// 82 83 namespace mlir { 84 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 85 RewritePatternSet &patterns) { 86 87 // GLSL patterns 88 patterns 89 .add<Log1pOpPattern<spirv::GLSLLogOp>, ExpM1OpPattern<spirv::GLSLExpOp>, 90 spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>, 91 spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>, 92 spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>, 93 spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>, 94 spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>, 95 spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>, 96 spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>, 97 spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>, 98 spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>, 99 spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>, 100 spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>, 101 spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>( 102 typeConverter, patterns.getContext()); 103 104 // OpenCL patterns 105 patterns.add<Log1pOpPattern<spirv::OCLLogOp>, ExpM1OpPattern<spirv::OCLExpOp>, 106 spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>, 107 spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>, 108 spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>, 109 spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>, 110 spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>, 111 spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>, 112 spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>, 113 spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>, 114 spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>, 115 spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>, 116 spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>, 117 spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>, 118 spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>( 119 typeConverter, patterns.getContext()); 120 } 121 122 } // namespace mlir 123