1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "../SPIRVCommon/Pattern.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/APInt.h"
#include "llvm/Support/Debug.h"
20 
21 #define DEBUG_TYPE "math-to-spirv-pattern"
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // Operation conversion
27 //===----------------------------------------------------------------------===//
28 
// Note that DRR cannot be used for the patterns in this file: we may need to
// convert types along the way, which requires a ConversionPattern, whereas DRR
// generates plain RewritePatterns.
32 
33 namespace {
34 /// Converts math.copysign to SPIR-V ops.
35 class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
36   using OpConversionPattern::OpConversionPattern;
37 
38   LogicalResult
39   matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
40                   ConversionPatternRewriter &rewriter) const override {
41     auto type = getTypeConverter()->convertType(copySignOp.getType());
42     if (!type)
43       return failure();
44 
45     FloatType floatType;
46     if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) {
47       floatType = scalarType;
48     } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
49       floatType = vectorType.getElementType().cast<FloatType>();
50     } else {
51       return failure();
52     }
53 
54     Location loc = copySignOp.getLoc();
55     int bitwidth = floatType.getWidth();
56     Type intType = rewriter.getIntegerType(bitwidth);
57 
58     Value signMask = rewriter.create<spirv::ConstantOp>(
59         loc, intType, rewriter.getIntegerAttr(intType, (1u << (bitwidth - 1))));
60     Value valueMask = rewriter.create<spirv::ConstantOp>(
61         loc, intType,
62         rewriter.getIntegerAttr(intType, (1u << (bitwidth - 1)) - 1u));
63 
64     if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
65       assert(vectorType.getRank() == 1);
66       int count = vectorType.getNumElements();
67       intType = VectorType::get(count, intType);
68 
69       SmallVector<Value> signSplat(count, signMask);
70       signMask =
71           rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
72 
73       SmallVector<Value> valueSplat(count, valueMask);
74       valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
75                                                                valueSplat);
76     }
77 
78     Value lhsCast =
79         rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
80     Value rhsCast =
81         rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
82 
83     Value value = rewriter.create<spirv::BitwiseAndOp>(
84         loc, intType, ValueRange{lhsCast, valueMask});
85     Value sign = rewriter.create<spirv::BitwiseAndOp>(
86         loc, intType, ValueRange{rhsCast, signMask});
87 
88     Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
89                                                        ValueRange{value, sign});
90     rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
91     return success();
92   }
93 };
94 
95 /// Converts math.expm1 to SPIR-V ops.
96 ///
97 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
98 /// these operations.
99 template <typename ExpOp>
100 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
101   using OpConversionPattern::OpConversionPattern;
102 
103   LogicalResult
104   matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
105                   ConversionPatternRewriter &rewriter) const override {
106     assert(adaptor.getOperands().size() == 1);
107     Location loc = operation.getLoc();
108     auto type = this->getTypeConverter()->convertType(operation.getType());
109     auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
110     auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
111     rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
112     return success();
113   }
114 };
115 
116 /// Converts math.log1p to SPIR-V ops.
117 ///
118 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
119 /// these operations.
120 template <typename LogOp>
121 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
122   using OpConversionPattern::OpConversionPattern;
123 
124   LogicalResult
125   matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
126                   ConversionPatternRewriter &rewriter) const override {
127     assert(adaptor.getOperands().size() == 1);
128     Location loc = operation.getLoc();
129     auto type = this->getTypeConverter()->convertType(operation.getType());
130     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
131     auto onePlus =
132         rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
133     rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
134     return success();
135   }
136 };
137 } // namespace
138 
139 //===----------------------------------------------------------------------===//
140 // Pattern population
141 //===----------------------------------------------------------------------===//
142 
143 namespace mlir {
144 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
145                                  RewritePatternSet &patterns) {
146   // Core patterns
147   patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
148 
149   // GLSL patterns
150   patterns
151       .add<Log1pOpPattern<spirv::GLSLLogOp>, ExpM1OpPattern<spirv::GLSLExpOp>,
152            spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
153            spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
154            spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
155            spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>,
156            spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
157            spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>,
158            spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
159            spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
160            spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
161            spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
162            spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
163            spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
164           typeConverter, patterns.getContext());
165 
166   // OpenCL patterns
167   patterns.add<Log1pOpPattern<spirv::OCLLogOp>, ExpM1OpPattern<spirv::OCLExpOp>,
168                spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
169                spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>,
170                spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>,
171                spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>,
172                spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>,
173                spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>,
174                spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>,
175                spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
176                spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
177                spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
178                spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>,
179                spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
180                spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
181       typeConverter, patterns.getContext());
182 }
183 
184 } // namespace mlir
185