1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "../SPIRVCommon/Pattern.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/Support/Debug.h"
21 
22 #define DEBUG_TYPE "math-to-spirv-pattern"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Utility functions
28 //===----------------------------------------------------------------------===//
29 
30 /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
31 /// given type is not a 32-bit scalar/vector type.
32 static Value getScalarOrVectorI32Constant(Type type, int value,
33                                           OpBuilder &builder, Location loc) {
34   if (auto vectorType = type.dyn_cast<VectorType>()) {
35     if (!vectorType.getElementType().isInteger(32))
36       return nullptr;
37     SmallVector<int> values(vectorType.getNumElements(), value);
38     return builder.create<spirv::ConstantOp>(loc, type,
39                                              builder.getI32VectorAttr(values));
40   }
41   if (type.isInteger(32))
42     return builder.create<spirv::ConstantOp>(loc, type,
43                                              builder.getI32IntegerAttr(value));
44 
45   return nullptr;
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // Operation conversion
50 //===----------------------------------------------------------------------===//
51 
52 // Note that DRR cannot be used for the patterns in this file: we may need to
53 // convert type along the way, which requires ConversionPattern. DRR generates
54 // normal RewritePattern.
55 
56 namespace {
57 /// Converts math.copysign to SPIR-V ops.
58 class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
59   using OpConversionPattern::OpConversionPattern;
60 
61   LogicalResult
62   matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
63                   ConversionPatternRewriter &rewriter) const override {
64     auto type = getTypeConverter()->convertType(copySignOp.getType());
65     if (!type)
66       return failure();
67 
68     FloatType floatType;
69     if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) {
70       floatType = scalarType;
71     } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
72       floatType = vectorType.getElementType().cast<FloatType>();
73     } else {
74       return failure();
75     }
76 
77     Location loc = copySignOp.getLoc();
78     int bitwidth = floatType.getWidth();
79     Type intType = rewriter.getIntegerType(bitwidth);
80     uint64_t intValue = uint64_t(1) << (bitwidth - 1);
81 
82     Value signMask = rewriter.create<spirv::ConstantOp>(
83         loc, intType, rewriter.getIntegerAttr(intType, intValue));
84     Value valueMask = rewriter.create<spirv::ConstantOp>(
85         loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
86 
87     if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
88       assert(vectorType.getRank() == 1);
89       int count = vectorType.getNumElements();
90       intType = VectorType::get(count, intType);
91 
92       SmallVector<Value> signSplat(count, signMask);
93       signMask =
94           rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
95 
96       SmallVector<Value> valueSplat(count, valueMask);
97       valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
98                                                                valueSplat);
99     }
100 
101     Value lhsCast =
102         rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
103     Value rhsCast =
104         rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
105 
106     Value value = rewriter.create<spirv::BitwiseAndOp>(
107         loc, intType, ValueRange{lhsCast, valueMask});
108     Value sign = rewriter.create<spirv::BitwiseAndOp>(
109         loc, intType, ValueRange{rhsCast, signMask});
110 
111     Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
112                                                        ValueRange{value, sign});
113     rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
114     return success();
115   }
116 };
117 
118 /// Converts math.ctlz to SPIR-V ops.
119 ///
120 /// SPIR-V does not have a direct operations for counting leading zeros. If
121 /// Shader capability is supported, we can leverage GLSL FindUMsb to calculate
122 /// it.
123 class CountLeadingZerosPattern final
124     : public OpConversionPattern<math::CountLeadingZerosOp> {
125   using OpConversionPattern::OpConversionPattern;
126 
127   LogicalResult
128   matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
129                   ConversionPatternRewriter &rewriter) const override {
130     auto type = getTypeConverter()->convertType(countOp.getType());
131     if (!type)
132       return failure();
133 
134     // We can only support 32-bit integer types for now.
135     unsigned bitwidth = 0;
136     if (type.isa<IntegerType>())
137       bitwidth = type.getIntOrFloatBitWidth();
138     if (auto vectorType = type.dyn_cast<VectorType>())
139       bitwidth = vectorType.getElementTypeBitWidth();
140     if (bitwidth != 32)
141       return failure();
142 
143     Location loc = countOp.getLoc();
144     Value input = adaptor.getOperand();
145     Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
146     Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
147     Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
148 
149     Value msb = rewriter.create<spirv::GLSLFindUMsbOp>(loc, input);
150     // We need to subtract from 31 given that the index returned by GLSL
151     // FindUMsb is counted from the least significant bit. Theoretically this
152     // also gives the correct result even if the integer has all zero bits, in
153     // which case GLSL FindUMsb would return -1.
154     Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
155     // However, certain Vulkan implementations have driver bugs for the corner
156     // case where the input is zero. And.. it can be smart to optimize a select
157     // only involving the corner case. So separately compute the result when the
158     // input is either zero or one.
159     Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
160     Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
161     rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
162                                                  subMsb);
163     return success();
164   }
165 };
166 
167 /// Converts math.expm1 to SPIR-V ops.
168 ///
169 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
170 /// these operations.
171 template <typename ExpOp>
172 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
173   using OpConversionPattern::OpConversionPattern;
174 
175   LogicalResult
176   matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
177                   ConversionPatternRewriter &rewriter) const override {
178     assert(adaptor.getOperands().size() == 1);
179     Location loc = operation.getLoc();
180     auto type = this->getTypeConverter()->convertType(operation.getType());
181     auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
182     auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
183     rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
184     return success();
185   }
186 };
187 
188 /// Converts math.log1p to SPIR-V ops.
189 ///
190 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
191 /// these operations.
192 template <typename LogOp>
193 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
194   using OpConversionPattern::OpConversionPattern;
195 
196   LogicalResult
197   matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
198                   ConversionPatternRewriter &rewriter) const override {
199     assert(adaptor.getOperands().size() == 1);
200     Location loc = operation.getLoc();
201     auto type = this->getTypeConverter()->convertType(operation.getType());
202     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
203     auto onePlus =
204         rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
205     rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
206     return success();
207   }
208 };
209 
210 /// Converts math.powf to SPIRV-Ops.
211 struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
212   using OpConversionPattern::OpConversionPattern;
213 
214   LogicalResult
215   matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
216                   ConversionPatternRewriter &rewriter) const override {
217     auto dstType = getTypeConverter()->convertType(powfOp.getType());
218     if (!dstType)
219       return failure();
220 
221     // Per GLSL Pow extended instruction spec:
222     // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
223     Location loc = powfOp.getLoc();
224     Value zero =
225         spirv::ConstantOp::getZero(adaptor.getLhs().getType(), loc, rewriter);
226     Value lessThan =
227         rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
228     Value abs = rewriter.create<spirv::GLSLFAbsOp>(loc, adaptor.getLhs());
229     Value pow = rewriter.create<spirv::GLSLPowOp>(loc, abs, adaptor.getRhs());
230     Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
231     rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, lessThan, negate, pow);
232     return success();
233   }
234 };
235 
236 } // namespace
237 
238 //===----------------------------------------------------------------------===//
239 // Pattern population
240 //===----------------------------------------------------------------------===//
241 
242 namespace mlir {
243 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
244                                  RewritePatternSet &patterns) {
245   // Core patterns
246   patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
247 
248   // GLSL patterns
249   patterns
250       .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>,
251            ExpM1OpPattern<spirv::GLSLExpOp>, PowFOpPattern,
252            spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
253            spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
254            spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
255            spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>,
256            spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
257            spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>,
258            spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
259            spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
260            spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
261            spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
262            spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
263           typeConverter, patterns.getContext());
264 
265   // OpenCL patterns
266   patterns.add<Log1pOpPattern<spirv::OCLLogOp>, ExpM1OpPattern<spirv::OCLExpOp>,
267                spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
268                spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>,
269                spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>,
270                spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>,
271                spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>,
272                spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>,
273                spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>,
274                spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
275                spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
276                spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
277                spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>,
278                spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
279                spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
280       typeConverter, patterns.getContext());
281 }
282 
283 } // namespace mlir
284