1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "../SPIRVCommon/Pattern.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "math-to-spirv-pattern"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Operation conversion
26 //===----------------------------------------------------------------------===//
27 
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires a ConversionPattern, whereas DRR
// only generates plain RewritePatterns.
31 
32 namespace {
33 /// Converts math.log1p to SPIR-V ops.
34 ///
35 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
36 /// these operations.
37 template <typename LogOp>
38 class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
39 public:
40   using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
41 
42   LogicalResult
43   matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
44                   ConversionPatternRewriter &rewriter) const override {
45     assert(adaptor.getOperands().size() == 1);
46     Location loc = operation.getLoc();
47     auto type =
48         this->getTypeConverter()->convertType(operation.getOperand().getType());
49     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
50     auto onePlus =
51         rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]);
52     rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
53     return success();
54   }
55 };
56 } // namespace
57 
58 //===----------------------------------------------------------------------===//
59 // Pattern population
60 //===----------------------------------------------------------------------===//
61 
62 namespace mlir {
63 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
64                                  RewritePatternSet &patterns) {
65 
66   // GLSL patterns
67   patterns.add<
68       Log1pOpPattern<spirv::GLSLLogOp>,
69       spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
70       spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
71       spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
72       spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>,
73       spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
74       spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>,
75       spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>,
76       spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
77       spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>,
78       spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
79       spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
80       typeConverter, patterns.getContext());
81 
82   // OpenCL patterns
83   patterns.add<Log1pOpPattern<spirv::OCLLogOp>,
84                spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
85                spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::OCLCeilOp>,
86                spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::OCLCosOp>,
87                spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::OCLExpOp>,
88                spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::OCLFloorOp>,
89                spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::OCLLogOp>,
90                spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::OCLPowOp>,
91                spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
92                spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::OCLSinOp>,
93                spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
94                spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
95       typeConverter, patterns.getContext());
96 }
97 
98 } // namespace mlir
99