1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "../SPIRVCommon/Pattern.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/IR/TypeUtilities.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/Support/Debug.h"
22 
23 #define DEBUG_TYPE "math-to-spirv-pattern"
24 
25 using namespace mlir;
26 
27 //===----------------------------------------------------------------------===//
28 // Utility functions
29 //===----------------------------------------------------------------------===//
30 
31 /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
32 /// given type is not a 32-bit scalar/vector type.
getScalarOrVectorI32Constant(Type type,int value,OpBuilder & builder,Location loc)33 static Value getScalarOrVectorI32Constant(Type type, int value,
34                                           OpBuilder &builder, Location loc) {
35   if (auto vectorType = type.dyn_cast<VectorType>()) {
36     if (!vectorType.getElementType().isInteger(32))
37       return nullptr;
38     SmallVector<int> values(vectorType.getNumElements(), value);
39     return builder.create<spirv::ConstantOp>(loc, type,
40                                              builder.getI32VectorAttr(values));
41   }
42   if (type.isInteger(32))
43     return builder.create<spirv::ConstantOp>(loc, type,
44                                              builder.getI32IntegerAttr(value));
45 
46   return nullptr;
47 }
48 
49 //===----------------------------------------------------------------------===//
50 // Operation conversion
51 //===----------------------------------------------------------------------===//
52 
53 // Note that DRR cannot be used for the patterns in this file: we may need to
54 // convert type along the way, which requires ConversionPattern. DRR generates
55 // normal RewritePattern.
56 
57 namespace {
58 /// Converts math.copysign to SPIR-V ops.
59 class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
60   using OpConversionPattern::OpConversionPattern;
61 
62   LogicalResult
matchAndRewrite(math::CopySignOp copySignOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const63   matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
64                   ConversionPatternRewriter &rewriter) const override {
65     auto type = getTypeConverter()->convertType(copySignOp.getType());
66     if (!type)
67       return failure();
68 
69     FloatType floatType;
70     if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) {
71       floatType = scalarType;
72     } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
73       floatType = vectorType.getElementType().cast<FloatType>();
74     } else {
75       return failure();
76     }
77 
78     Location loc = copySignOp.getLoc();
79     int bitwidth = floatType.getWidth();
80     Type intType = rewriter.getIntegerType(bitwidth);
81     uint64_t intValue = uint64_t(1) << (bitwidth - 1);
82 
83     Value signMask = rewriter.create<spirv::ConstantOp>(
84         loc, intType, rewriter.getIntegerAttr(intType, intValue));
85     Value valueMask = rewriter.create<spirv::ConstantOp>(
86         loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
87 
88     if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
89       assert(vectorType.getRank() == 1);
90       int count = vectorType.getNumElements();
91       intType = VectorType::get(count, intType);
92 
93       SmallVector<Value> signSplat(count, signMask);
94       signMask =
95           rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
96 
97       SmallVector<Value> valueSplat(count, valueMask);
98       valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
99                                                                valueSplat);
100     }
101 
102     Value lhsCast =
103         rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
104     Value rhsCast =
105         rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
106 
107     Value value = rewriter.create<spirv::BitwiseAndOp>(
108         loc, intType, ValueRange{lhsCast, valueMask});
109     Value sign = rewriter.create<spirv::BitwiseAndOp>(
110         loc, intType, ValueRange{rhsCast, signMask});
111 
112     Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
113                                                        ValueRange{value, sign});
114     rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
115     return success();
116   }
117 };
118 
119 /// Converts math.ctlz to SPIR-V ops.
120 ///
121 /// SPIR-V does not have a direct operations for counting leading zeros. If
122 /// Shader capability is supported, we can leverage GL FindUMsb to calculate
123 /// it.
124 class CountLeadingZerosPattern final
125     : public OpConversionPattern<math::CountLeadingZerosOp> {
126   using OpConversionPattern::OpConversionPattern;
127 
128   LogicalResult
matchAndRewrite(math::CountLeadingZerosOp countOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const129   matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
130                   ConversionPatternRewriter &rewriter) const override {
131     auto type = getTypeConverter()->convertType(countOp.getType());
132     if (!type)
133       return failure();
134 
135     // We can only support 32-bit integer types for now.
136     unsigned bitwidth = 0;
137     if (type.isa<IntegerType>())
138       bitwidth = type.getIntOrFloatBitWidth();
139     if (auto vectorType = type.dyn_cast<VectorType>())
140       bitwidth = vectorType.getElementTypeBitWidth();
141     if (bitwidth != 32)
142       return failure();
143 
144     Location loc = countOp.getLoc();
145     Value input = adaptor.getOperand();
146     Value val1 = getScalarOrVectorI32Constant(type, 1, rewriter, loc);
147     Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
148     Value val32 = getScalarOrVectorI32Constant(type, 32, rewriter, loc);
149 
150     Value msb = rewriter.create<spirv::GLFindUMsbOp>(loc, input);
151     // We need to subtract from 31 given that the index returned by GLSL
152     // FindUMsb is counted from the least significant bit. Theoretically this
153     // also gives the correct result even if the integer has all zero bits, in
154     // which case GL FindUMsb would return -1.
155     Value subMsb = rewriter.create<spirv::ISubOp>(loc, val31, msb);
156     // However, certain Vulkan implementations have driver bugs for the corner
157     // case where the input is zero. And.. it can be smart to optimize a select
158     // only involving the corner case. So separately compute the result when the
159     // input is either zero or one.
160     Value subInput = rewriter.create<spirv::ISubOp>(loc, val32, input);
161     Value cmp = rewriter.create<spirv::ULessThanEqualOp>(loc, input, val1);
162     rewriter.replaceOpWithNewOp<spirv::SelectOp>(countOp, cmp, subInput,
163                                                  subMsb);
164     return success();
165   }
166 };
167 
168 /// Converts math.expm1 to SPIR-V ops.
169 ///
170 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
171 /// these operations.
172 template <typename ExpOp>
173 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
174   using OpConversionPattern::OpConversionPattern;
175 
176   LogicalResult
matchAndRewrite__anon99c81a7f0111::ExpM1OpPattern177   matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
178                   ConversionPatternRewriter &rewriter) const override {
179     assert(adaptor.getOperands().size() == 1);
180     Location loc = operation.getLoc();
181     auto type = this->getTypeConverter()->convertType(operation.getType());
182     auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
183     auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
184     rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
185     return success();
186   }
187 };
188 
189 /// Converts math.log1p to SPIR-V ops.
190 ///
191 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
192 /// these operations.
193 template <typename LogOp>
194 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
195   using OpConversionPattern::OpConversionPattern;
196 
197   LogicalResult
matchAndRewrite__anon99c81a7f0111::Log1pOpPattern198   matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
199                   ConversionPatternRewriter &rewriter) const override {
200     assert(adaptor.getOperands().size() == 1);
201     Location loc = operation.getLoc();
202     auto type = this->getTypeConverter()->convertType(operation.getType());
203     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
204     auto onePlus =
205         rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
206     rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
207     return success();
208   }
209 };
210 
211 /// Converts math.powf to SPIRV-Ops.
212 struct PowFOpPattern final : public OpConversionPattern<math::PowFOp> {
213   using OpConversionPattern::OpConversionPattern;
214 
215   LogicalResult
matchAndRewrite__anon99c81a7f0111::PowFOpPattern216   matchAndRewrite(math::PowFOp powfOp, OpAdaptor adaptor,
217                   ConversionPatternRewriter &rewriter) const override {
218     auto dstType = getTypeConverter()->convertType(powfOp.getType());
219     if (!dstType)
220       return failure();
221 
222     // Per GL Pow extended instruction spec:
223     // "Result is undefined if x < 0. Result is undefined if x = 0 and y <= 0."
224     Location loc = powfOp.getLoc();
225     Value zero =
226         spirv::ConstantOp::getZero(adaptor.getLhs().getType(), loc, rewriter);
227     Value lessThan =
228         rewriter.create<spirv::FOrdLessThanOp>(loc, adaptor.getLhs(), zero);
229     Value abs = rewriter.create<spirv::GLFAbsOp>(loc, adaptor.getLhs());
230     Value pow = rewriter.create<spirv::GLPowOp>(loc, abs, adaptor.getRhs());
231     Value negate = rewriter.create<spirv::FNegateOp>(loc, pow);
232     rewriter.replaceOpWithNewOp<spirv::SelectOp>(powfOp, lessThan, negate, pow);
233     return success();
234   }
235 };
236 
237 /// Converts math.round to GLSL SPIRV extended ops.
238 struct RoundOpPattern final : public OpConversionPattern<math::RoundOp> {
239   using OpConversionPattern::OpConversionPattern;
240 
241   LogicalResult
matchAndRewrite__anon99c81a7f0111::RoundOpPattern242   matchAndRewrite(math::RoundOp roundOp, OpAdaptor adaptor,
243                   ConversionPatternRewriter &rewriter) const override {
244     Location loc = roundOp.getLoc();
245     auto operand = roundOp.getOperand();
246     auto ty = operand.getType();
247     auto ety = getElementTypeOrSelf(ty);
248 
249     auto zero = spirv::ConstantOp::getZero(ty, loc, rewriter);
250     auto one = spirv::ConstantOp::getOne(ty, loc, rewriter);
251     Value half;
252     if (VectorType vty = ty.dyn_cast<VectorType>()) {
253       half = rewriter.create<spirv::ConstantOp>(
254           loc, vty,
255           DenseElementsAttr::get(vty,
256                                  rewriter.getFloatAttr(ety, 0.5).getValue()));
257     } else {
258       half = rewriter.create<spirv::ConstantOp>(
259           loc, ty, rewriter.getFloatAttr(ety, 0.5));
260     }
261 
262     auto abs = rewriter.create<spirv::GLFAbsOp>(loc, operand);
263     auto floor = rewriter.create<spirv::GLFloorOp>(loc, abs);
264     auto sub = rewriter.create<spirv::FSubOp>(loc, abs, floor);
265     auto greater =
266         rewriter.create<spirv::FOrdGreaterThanEqualOp>(loc, sub, half);
267     auto select = rewriter.create<spirv::SelectOp>(loc, greater, one, zero);
268     auto add = rewriter.create<spirv::FAddOp>(loc, floor, select);
269     rewriter.replaceOpWithNewOp<math::CopySignOp>(roundOp, add, operand);
270     return success();
271   }
272 };
273 
274 } // namespace
275 
276 //===----------------------------------------------------------------------===//
277 // Pattern population
278 //===----------------------------------------------------------------------===//
279 
280 namespace mlir {
populateMathToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)281 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
282                                  RewritePatternSet &patterns) {
283   // Core patterns
284   patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
285 
286   // GLSL patterns
287   patterns
288       .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLLogOp>,
289            ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
290            spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLFAbsOp>,
291            spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
292            spirv::ElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
293            spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
294            spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLFloorOp>,
295            spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLFmaOp>,
296            spirv::ElementwiseOpPattern<math::LogOp, spirv::GLLogOp>,
297            spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLInverseSqrtOp>,
298            spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSinOp>,
299            spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSqrtOp>,
300            spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLTanhOp>>(
301           typeConverter, patterns.getContext());
302 
303   // OpenCL patterns
304   patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
305                spirv::ElementwiseOpPattern<math::AbsOp, spirv::CLFAbsOp>,
306                spirv::ElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
307                spirv::ElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
308                spirv::ElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,
309                spirv::ElementwiseOpPattern<math::ExpOp, spirv::CLExpOp>,
310                spirv::ElementwiseOpPattern<math::FloorOp, spirv::CLFloorOp>,
311                spirv::ElementwiseOpPattern<math::FmaOp, spirv::CLFmaOp>,
312                spirv::ElementwiseOpPattern<math::LogOp, spirv::CLLogOp>,
313                spirv::ElementwiseOpPattern<math::PowFOp, spirv::CLPowOp>,
314                spirv::ElementwiseOpPattern<math::RoundOp, spirv::CLRoundOp>,
315                spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::CLRsqrtOp>,
316                spirv::ElementwiseOpPattern<math::SinOp, spirv::CLSinOp>,
317                spirv::ElementwiseOpPattern<math::SqrtOp, spirv::CLSqrtOp>,
318                spirv::ElementwiseOpPattern<math::TanhOp, spirv::CLTanhOp>>(
319       typeConverter, patterns.getContext());
320 }
321 
322 } // namespace mlir
323