//===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns to convert Math dialect to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "../SPIRVCommon/Pattern.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "math-to-spirv-pattern"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Operation conversion
//===----------------------------------------------------------------------===//

// Note that DRR cannot be used for the patterns in this file: we may need to
// convert type along the way, which requires ConversionPattern. DRR generates
// normal RewritePattern.

namespace {
/// Converts math.log1p to SPIR-V ops.
///
/// SPIR-V does not have a direct operation for log(1+x), so explicitly lower
/// it to a floating-point addition followed by a log operation.
37 template <typename LogOp> 38 class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { 39 public: 40 using OpConversionPattern<math::Log1pOp>::OpConversionPattern; 41 42 LogicalResult 43 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, 44 ConversionPatternRewriter &rewriter) const override { 45 assert(adaptor.getOperands().size() == 1); 46 Location loc = operation.getLoc(); 47 auto type = 48 this->getTypeConverter()->convertType(operation.getOperand().getType()); 49 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); 50 auto onePlus = 51 rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]); 52 rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus); 53 return success(); 54 } 55 }; 56 } // namespace 57 58 //===----------------------------------------------------------------------===// 59 // Pattern population 60 //===----------------------------------------------------------------------===// 61 62 namespace mlir { 63 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 64 RewritePatternSet &patterns) { 65 66 // GLSL patterns 67 patterns.add< 68 Log1pOpPattern<spirv::GLSLLogOp>, 69 spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::GLSLFAbsOp>, 70 spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::GLSLCeilOp>, 71 spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>, 72 spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>, 73 spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::GLSLFloorOp>, 74 spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>, 75 spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>, 76 spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>, 77 spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>, 78 spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>, 79 spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>( 80 typeConverter, patterns.getContext()); 81 82 // OpenCL patterns 83 
patterns.add<Log1pOpPattern<spirv::OCLLogOp>, 84 spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::OCLFAbsOp>, 85 spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::OCLCeilOp>, 86 spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::OCLCosOp>, 87 spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::OCLExpOp>, 88 spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::OCLFloorOp>, 89 spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::OCLLogOp>, 90 spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::OCLPowOp>, 91 spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>, 92 spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::OCLSinOp>, 93 spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::OCLSqrtOp>, 94 spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::OCLTanhOp>>( 95 typeConverter, patterns.getContext()); 96 } 97 98 } // namespace mlir 99