1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Math/IR/Math.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17 #include "llvm/Support/Debug.h"
18 
19 #define DEBUG_TYPE "math-to-spirv-pattern"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Operation conversion
25 //===----------------------------------------------------------------------===//
26 
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires ConversionPattern. DRR generates
// normal RewritePatterns.
30 
31 namespace {
32 
33 /// Converts unary and binary standard operations to SPIR-V operations.
34 template <typename StdOp, typename SPIRVOp>
35 class UnaryAndBinaryOpPattern final : public OpConversionPattern<StdOp> {
36 public:
37   using OpConversionPattern<StdOp>::OpConversionPattern;
38 
39   LogicalResult
40   matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor,
41                   ConversionPatternRewriter &rewriter) const override {
42     assert(adaptor.getOperands().size() <= 2);
43     auto dstType = this->getTypeConverter()->convertType(operation.getType());
44     if (!dstType)
45       return failure();
46     if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
47         dstType != operation.getType()) {
48       return operation.emitError(
49           "bitwidth emulation is not implemented yet on unsigned op");
50     }
51     rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType,
52                                                   adaptor.getOperands());
53     return success();
54   }
55 };
56 
57 /// Converts math.log1p to SPIR-V ops.
58 ///
59 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
60 /// these operations.
61 class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
62 public:
63   using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
64 
65   LogicalResult
66   matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
67                   ConversionPatternRewriter &rewriter) const override {
68     assert(adaptor.getOperands().size() == 1);
69     Location loc = operation.getLoc();
70     auto type =
71         this->getTypeConverter()->convertType(operation.operand().getType());
72     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
73     auto onePlus =
74         rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperands()[0]);
75     rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus);
76     return success();
77   }
78 };
79 
80 } // namespace
81 
82 //===----------------------------------------------------------------------===//
83 // Pattern population
84 //===----------------------------------------------------------------------===//
85 
86 namespace mlir {
87 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
88                                  RewritePatternSet &patterns) {
89   patterns.add<Log1pOpPattern,
90                UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
91                UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>,
92                UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>,
93                UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
94                UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>,
95                UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>,
96                UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
97                UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
98       typeConverter, patterns.getContext());
99 }
100 
101 } // namespace mlir
102