1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/Math/IR/Math.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17 #include "llvm/Support/Debug.h"
18 
19 #define DEBUG_TYPE "math-to-spirv-pattern"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Operation conversion
25 //===----------------------------------------------------------------------===//
26 
27 // Note that DRR cannot be used for the patterns in this file: we may need to
28 // convert type along the way, which requires ConversionPattern. DRR generates
29 // normal RewritePattern.
30 
31 namespace {
32 
33 /// Converts unary and binary standard operations to SPIR-V operations.
34 template <typename StdOp, typename SPIRVOp>
35 class UnaryAndBinaryOpPattern final : public OpConversionPattern<StdOp> {
36 public:
37   using OpConversionPattern<StdOp>::OpConversionPattern;
38 
39   LogicalResult
40   matchAndRewrite(StdOp operation, ArrayRef<Value> operands,
41                   ConversionPatternRewriter &rewriter) const override {
42     assert(operands.size() <= 2);
43     auto dstType = this->getTypeConverter()->convertType(operation.getType());
44     if (!dstType)
45       return failure();
46     if (SPIRVOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&
47         dstType != operation.getType()) {
48       return operation.emitError(
49           "bitwidth emulation is not implemented yet on unsigned op");
50     }
51     rewriter.template replaceOpWithNewOp<SPIRVOp>(operation, dstType, operands);
52     return success();
53   }
54 };
55 
56 /// Converts math.log1p to SPIR-V ops.
57 ///
58 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
59 /// these operations.
60 class Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
61 public:
62   using OpConversionPattern<math::Log1pOp>::OpConversionPattern;
63 
64   LogicalResult
65   matchAndRewrite(math::Log1pOp operation, ArrayRef<Value> operands,
66                   ConversionPatternRewriter &rewriter) const override {
67     assert(operands.size() == 1);
68     Location loc = operation.getLoc();
69     auto type =
70         this->getTypeConverter()->convertType(operation.operand().getType());
71     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
72     auto onePlus = rewriter.create<spirv::FAddOp>(loc, one, operands[0]);
73     rewriter.replaceOpWithNewOp<spirv::GLSLLogOp>(operation, type, onePlus);
74     return success();
75   }
76 };
77 
78 } // namespace
79 
80 //===----------------------------------------------------------------------===//
81 // Pattern population
82 //===----------------------------------------------------------------------===//
83 
84 namespace mlir {
85 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
86                                  RewritePatternSet &patterns) {
87   patterns.add<Log1pOpPattern,
88                UnaryAndBinaryOpPattern<math::CosOp, spirv::GLSLCosOp>,
89                UnaryAndBinaryOpPattern<math::ExpOp, spirv::GLSLExpOp>,
90                UnaryAndBinaryOpPattern<math::LogOp, spirv::GLSLLogOp>,
91                UnaryAndBinaryOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
92                UnaryAndBinaryOpPattern<math::PowFOp, spirv::GLSLPowOp>,
93                UnaryAndBinaryOpPattern<math::SinOp, spirv::GLSLSinOp>,
94                UnaryAndBinaryOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
95                UnaryAndBinaryOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
96       typeConverter, patterns.getContext());
97 }
98 
99 } // namespace mlir
100