1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "../SPIRVCommon/Pattern.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "math-to-spirv-pattern"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Operation conversion
26 //===----------------------------------------------------------------------===//
27 
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires a ConversionPattern, whereas DRR
// only generates plain RewritePatterns.
31 
32 namespace {
33 /// Converts math.expm1 to SPIR-V ops.
34 ///
35 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
36 /// these operations.
37 template <typename ExpOp>
38 class ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
39 public:
40   using OpConversionPattern<math::ExpM1Op>::OpConversionPattern;
41 
42   LogicalResult
43   matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
44                   ConversionPatternRewriter &rewriter) const override {
45     assert(adaptor.getOperands().size() == 1);
46     Location loc = operation.getLoc();
47     auto type = this->getTypeConverter()->convertType(operation.getType());
48     auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
49     auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
50     rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
51     return success();
52   }
53 };
54 
55 /// Converts math.log1p to SPIR-V ops.
56 ///
57 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
58 /// these operations.
59 template <typename LogOp>
60 class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
61 public:
62   using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
63 
64   LogicalResult
65   matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
66                   ConversionPatternRewriter &rewriter) const override {
67     assert(adaptor.getOperands().size() == 1);
68     Location loc = operation.getLoc();
69     auto type = this->getTypeConverter()->convertType(operation.getType());
70     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
71     auto onePlus =
72         rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
73     rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
74     return success();
75   }
76 };
77 } // namespace
78 
79 //===----------------------------------------------------------------------===//
80 // Pattern population
81 //===----------------------------------------------------------------------===//
82 
83 namespace mlir {
84 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
85                                  RewritePatternSet &patterns) {
86 
87   // GLSL patterns
88   patterns
89       .add<Log1pOpPattern<spirv::GLSLLogOp>, ExpM1OpPattern<spirv::GLSLExpOp>,
90            spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
91            spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
92            spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
93            spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>,
94            spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
95            spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>,
96            spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
97            spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
98            spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
99            spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
100            spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
101            spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
102           typeConverter, patterns.getContext());
103 
104   // OpenCL patterns
105   patterns.add<Log1pOpPattern<spirv::OCLLogOp>, ExpM1OpPattern<spirv::OCLExpOp>,
106                spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
107                spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>,
108                spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>,
109                spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>,
110                spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>,
111                spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>,
112                spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>,
113                spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
114                spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
115                spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
116                spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>,
117                spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
118                spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
119       typeConverter, patterns.getContext());
120 }
121 
122 } // namespace mlir
123