1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements patterns to convert Math dialect to SPIR-V dialect. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "../SPIRVCommon/Pattern.h" 14 #include "mlir/Dialect/Math/IR/Math.h" 15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" 16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" 17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" 18 #include "llvm/Support/Debug.h" 19 20 #define DEBUG_TYPE "math-to-spirv-pattern" 21 22 using namespace mlir; 23 24 //===----------------------------------------------------------------------===// 25 // Operation conversion 26 //===----------------------------------------------------------------------===// 27 28 // Note that DRR cannot be used for the patterns in this file: we may need to 29 // convert type along the way, which requires ConversionPattern. DRR generates 30 // normal RewritePattern. 31 32 namespace { 33 /// Converts math.log1p to SPIR-V ops. 34 /// 35 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to 36 /// these operations. 37 class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> { 38 public: 39 using OpConversionPattern<math::Log1pOp>::OpConversionPattern; 40 41 LogicalResult 42 matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, 43 ConversionPatternRewriter &rewriter) const override { 44 assert(adaptor.getOperands().size() == 1); 45 Location loc = operation.getLoc(); 46 auto type = 47 this->getTypeConverter()->convertType(operation.operand().getType()); 48 auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); 49 auto onePlus = 50 rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]); 51 rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus); 52 return success(); 53 } 54 }; 55 } // namespace 56 57 //===----------------------------------------------------------------------===// 58 // Pattern population 59 //===----------------------------------------------------------------------===// 60 61 namespace mlir { 62 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter, 63 RewritePatternSet &patterns) { 64 patterns.add< 65 Log1pOpPattern, 66 spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::GLSLFAbsOp>, 67 spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::GLSLCeilOp>, 68 spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>, 69 spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>, 70 spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::GLSLFloorOp>, 71 spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>, 72 spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>, 73 spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>, 74 spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>, 75 spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>, 76 spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>( 77 typeConverter, patterns.getContext()); 78 } 79 80 } // namespace mlir 81