1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "../SPIRVCommon/Pattern.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "llvm/Support/Debug.h"
19 
20 #define DEBUG_TYPE "math-to-spirv-pattern"
21 
22 using namespace mlir;
23 
24 //===----------------------------------------------------------------------===//
25 // Operation conversion
26 //===----------------------------------------------------------------------===//
27 
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires a ConversionPattern, whereas DRR
// only generates plain RewritePatterns.
31 
32 namespace {
33 /// Converts math.log1p to SPIR-V ops.
34 ///
35 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
36 /// these operations.
37 class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
38 public:
39   using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
40 
41   LogicalResult
42   matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
43                   ConversionPatternRewriter &rewriter) const override {
44     assert(adaptor.getOperands().size() == 1);
45     Location loc = operation.getLoc();
46     auto type =
47         this->getTypeConverter()->convertType(operation.operand().getType());
48     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
49     auto onePlus =
50         rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]);
51     rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus);
52     return success();
53   }
54 };
55 } // namespace
56 
57 //===----------------------------------------------------------------------===//
58 // Pattern population
59 //===----------------------------------------------------------------------===//
60 
61 namespace mlir {
62 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
63                                  RewritePatternSet &patterns) {
64   patterns.add<
65       Log1pOpPattern,
66       spirv::UnaryAndBinaryOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
67       spirv::UnaryAndBinaryOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
68       spirv::UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
69       spirv::UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>,
70       spirv::UnaryAndBinaryOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
71       spirv::UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>,
72       spirv::UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>,
73       spirv::UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
74       spirv::UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>,
75       spirv::UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
76       spirv::UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
77       typeConverter, patterns.getContext());
78 }
79 
80 } // namespace mlir
81