1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "../SPIRVCommon/Pattern.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "math-to-spirv-pattern"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Operation conversion
26 //===----------------------------------------------------------------------===//
27 
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires a ConversionPattern, whereas DRR
// generates plain RewritePatterns.
31 
32 namespace {
33 /// Converts math.log1p to SPIR-V ops.
34 ///
35 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
36 /// these operations.
37 template <typename LogOp>
38 class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
39 public:
40   using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
41 
42   LogicalResult
43   matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
44                   ConversionPatternRewriter &rewriter) const override {
45     assert(adaptor.getOperands().size() == 1);
46     Location loc = operation.getLoc();
47     auto type =
48         this->getTypeConverter()->convertType(operation.getOperand().getType());
49     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
50     auto onePlus =
51         rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]);
52     rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
53     return success();
54   }
55 };
56 } // namespace
57 
58 //===----------------------------------------------------------------------===//
59 // Pattern population
60 //===----------------------------------------------------------------------===//
61 
62 namespace mlir {
63 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
64                                  RewritePatternSet &patterns) {
65 
66   // GLSL patterns
67   patterns
68       .add<Log1pOpPattern<spirv::GLSLLogOp>,
69            spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
70            spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
71            spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
72            spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>,
73            spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
74            spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
75            spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
76            spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
77            spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
78            spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
79            spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>,
80            spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>>(
81           typeConverter, patterns.getContext());
82 
83   // OpenCL patterns
84   patterns.add<Log1pOpPattern<spirv::OCLLogOp>,
85                spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
86                spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>,
87                spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>,
88                spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>,
89                spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>,
90                spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>,
91                spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
92                spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
93                spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
94                spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>,
95                spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
96                spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
97       typeConverter, patterns.getContext());
98 }
99 
100 } // namespace mlir
101