1 //===- MathToSPIRV.cpp - Math to SPIR-V Patterns --------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert Math dialect to SPIR-V dialect.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "../SPIRVCommon/Pattern.h"
14 #include "mlir/Dialect/Math/IR/Math.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
16 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
17 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
18 #include "mlir/IR/BuiltinTypes.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 #include "llvm/Support/Debug.h"
21 
22 #define DEBUG_TYPE "math-to-spirv-pattern"
23 
24 using namespace mlir;
25 
26 //===----------------------------------------------------------------------===//
27 // Utility functions
28 //===----------------------------------------------------------------------===//
29 
30 /// Creates a 32-bit scalar/vector integer constant. Returns nullptr if the
31 /// given type is not a 32-bit scalar/vector type.
32 static Value getScalarOrVectorI32Constant(Type type, int value,
33                                           OpBuilder &builder, Location loc) {
34   if (auto vectorType = type.dyn_cast<VectorType>()) {
35     if (!vectorType.getElementType().isInteger(32))
36       return nullptr;
37     SmallVector<int> values(vectorType.getNumElements(), value);
38     return builder.create<spirv::ConstantOp>(loc, type,
39                                              builder.getI32VectorAttr(values));
40   }
41   if (type.isInteger(32))
42     return builder.create<spirv::ConstantOp>(loc, type,
43                                              builder.getI32IntegerAttr(value));
44 
45   return nullptr;
46 }
47 
48 //===----------------------------------------------------------------------===//
49 // Operation conversion
50 //===----------------------------------------------------------------------===//
51 
52 // Note that DRR cannot be used for the patterns in this file: we may need to
53 // convert type along the way, which requires ConversionPattern. DRR generates
54 // normal RewritePattern.
55 
56 namespace {
57 /// Converts math.copysign to SPIR-V ops.
58 class CopySignPattern final : public OpConversionPattern<math::CopySignOp> {
59   using OpConversionPattern::OpConversionPattern;
60 
61   LogicalResult
62   matchAndRewrite(math::CopySignOp copySignOp, OpAdaptor adaptor,
63                   ConversionPatternRewriter &rewriter) const override {
64     auto type = getTypeConverter()->convertType(copySignOp.getType());
65     if (!type)
66       return failure();
67 
68     FloatType floatType;
69     if (auto scalarType = copySignOp.getType().dyn_cast<FloatType>()) {
70       floatType = scalarType;
71     } else if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
72       floatType = vectorType.getElementType().cast<FloatType>();
73     } else {
74       return failure();
75     }
76 
77     Location loc = copySignOp.getLoc();
78     int bitwidth = floatType.getWidth();
79     Type intType = rewriter.getIntegerType(bitwidth);
80     uint64_t intValue = uint64_t(1) << (bitwidth - 1);
81 
82     Value signMask = rewriter.create<spirv::ConstantOp>(
83         loc, intType, rewriter.getIntegerAttr(intType, intValue));
84     Value valueMask = rewriter.create<spirv::ConstantOp>(
85         loc, intType, rewriter.getIntegerAttr(intType, intValue - 1u));
86 
87     if (auto vectorType = copySignOp.getType().dyn_cast<VectorType>()) {
88       assert(vectorType.getRank() == 1);
89       int count = vectorType.getNumElements();
90       intType = VectorType::get(count, intType);
91 
92       SmallVector<Value> signSplat(count, signMask);
93       signMask =
94           rewriter.create<spirv::CompositeConstructOp>(loc, intType, signSplat);
95 
96       SmallVector<Value> valueSplat(count, valueMask);
97       valueMask = rewriter.create<spirv::CompositeConstructOp>(loc, intType,
98                                                                valueSplat);
99     }
100 
101     Value lhsCast =
102         rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getLhs());
103     Value rhsCast =
104         rewriter.create<spirv::BitcastOp>(loc, intType, adaptor.getRhs());
105 
106     Value value = rewriter.create<spirv::BitwiseAndOp>(
107         loc, intType, ValueRange{lhsCast, valueMask});
108     Value sign = rewriter.create<spirv::BitwiseAndOp>(
109         loc, intType, ValueRange{rhsCast, signMask});
110 
111     Value result = rewriter.create<spirv::BitwiseOrOp>(loc, intType,
112                                                        ValueRange{value, sign});
113     rewriter.replaceOpWithNewOp<spirv::BitcastOp>(copySignOp, type, result);
114     return success();
115   }
116 };
117 
118 /// Converts math.ctlz to SPIR-V ops.
119 ///
120 /// SPIR-V does not have a direct operations for counting leading zeros. If
121 /// Shader capability is supported, we can leverage GLSL FindUMsb to calculate
122 /// it.
123 class CountLeadingZerosPattern final
124     : public OpConversionPattern<math::CountLeadingZerosOp> {
125   using OpConversionPattern::OpConversionPattern;
126 
127   LogicalResult
128   matchAndRewrite(math::CountLeadingZerosOp countOp, OpAdaptor adaptor,
129                   ConversionPatternRewriter &rewriter) const override {
130     auto type = getTypeConverter()->convertType(countOp.getType());
131     if (!type)
132       return failure();
133 
134     // We can only support 32-bit integer types for now.
135     unsigned bitwidth = 0;
136     if (type.isa<IntegerType>())
137       bitwidth = type.getIntOrFloatBitWidth();
138     if (auto vectorType = type.dyn_cast<VectorType>())
139       bitwidth = vectorType.getElementTypeBitWidth();
140     if (bitwidth != 32)
141       return failure();
142 
143     Location loc = countOp.getLoc();
144     Value val31 = getScalarOrVectorI32Constant(type, 31, rewriter, loc);
145     Value msb =
146         rewriter.create<spirv::GLSLFindUMsbOp>(loc, adaptor.getOperand());
147     // We need to subtract from 31 given that the index is from the least
148     // significant bit.
149     rewriter.replaceOpWithNewOp<spirv::ISubOp>(countOp, val31, msb);
150     return success();
151   }
152 };
153 
154 /// Converts math.expm1 to SPIR-V ops.
155 ///
156 /// SPIR-V does not have a direct operations for exp(x)-1. Explicitly lower to
157 /// these operations.
158 template <typename ExpOp>
159 struct ExpM1OpPattern final : public OpConversionPattern<math::ExpM1Op> {
160   using OpConversionPattern::OpConversionPattern;
161 
162   LogicalResult
163   matchAndRewrite(math::ExpM1Op operation, OpAdaptor adaptor,
164                   ConversionPatternRewriter &rewriter) const override {
165     assert(adaptor.getOperands().size() == 1);
166     Location loc = operation.getLoc();
167     auto type = this->getTypeConverter()->convertType(operation.getType());
168     auto exp = rewriter.create<ExpOp>(loc, type, adaptor.getOperand());
169     auto one = spirv::ConstantOp::getOne(type, loc, rewriter);
170     rewriter.replaceOpWithNewOp<spirv::FSubOp>(operation, exp, one);
171     return success();
172   }
173 };
174 
175 /// Converts math.log1p to SPIR-V ops.
176 ///
177 /// SPIR-V does not have a direct operations for log(1+x). Explicitly lower to
178 /// these operations.
179 template <typename LogOp>
180 struct Log1pOpPattern final : public OpConversionPattern<math::Log1pOp> {
181   using OpConversionPattern::OpConversionPattern;
182 
183   LogicalResult
184   matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor,
185                   ConversionPatternRewriter &rewriter) const override {
186     assert(adaptor.getOperands().size() == 1);
187     Location loc = operation.getLoc();
188     auto type = this->getTypeConverter()->convertType(operation.getType());
189     auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter);
190     auto onePlus =
191         rewriter.create<spirv::FAddOp>(loc, one, adaptor.getOperand());
192     rewriter.replaceOpWithNewOp<LogOp>(operation, type, onePlus);
193     return success();
194   }
195 };
196 } // namespace
197 
198 //===----------------------------------------------------------------------===//
199 // Pattern population
200 //===----------------------------------------------------------------------===//
201 
202 namespace mlir {
203 void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
204                                  RewritePatternSet &patterns) {
205   // Core patterns
206   patterns.add<CopySignPattern>(typeConverter, patterns.getContext());
207 
208   // GLSL patterns
209   patterns
210       .add<CountLeadingZerosPattern, Log1pOpPattern<spirv::GLSLLogOp>,
211            ExpM1OpPattern<spirv::GLSLExpOp>,
212            spirv::ElementwiseOpPattern<math::AbsOp, spirv::GLSLFAbsOp>,
213            spirv::ElementwiseOpPattern<math::CeilOp, spirv::GLSLCeilOp>,
214            spirv::ElementwiseOpPattern<math::CosOp, spirv::GLSLCosOp>,
215            spirv::ElementwiseOpPattern<math::ExpOp, spirv::GLSLExpOp>,
216            spirv::ElementwiseOpPattern<math::FloorOp, spirv::GLSLFloorOp>,
217            spirv::ElementwiseOpPattern<math::FmaOp, spirv::GLSLFmaOp>,
218            spirv::ElementwiseOpPattern<math::LogOp, spirv::GLSLLogOp>,
219            spirv::ElementwiseOpPattern<math::PowFOp, spirv::GLSLPowOp>,
220            spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::GLSLInverseSqrtOp>,
221            spirv::ElementwiseOpPattern<math::SinOp, spirv::GLSLSinOp>,
222            spirv::ElementwiseOpPattern<math::SqrtOp, spirv::GLSLSqrtOp>,
223            spirv::ElementwiseOpPattern<math::TanhOp, spirv::GLSLTanhOp>>(
224           typeConverter, patterns.getContext());
225 
226   // OpenCL patterns
227   patterns.add<Log1pOpPattern<spirv::OCLLogOp>, ExpM1OpPattern<spirv::OCLExpOp>,
228                spirv::ElementwiseOpPattern<math::AbsOp, spirv::OCLFAbsOp>,
229                spirv::ElementwiseOpPattern<math::CeilOp, spirv::OCLCeilOp>,
230                spirv::ElementwiseOpPattern<math::CosOp, spirv::OCLCosOp>,
231                spirv::ElementwiseOpPattern<math::ErfOp, spirv::OCLErfOp>,
232                spirv::ElementwiseOpPattern<math::ExpOp, spirv::OCLExpOp>,
233                spirv::ElementwiseOpPattern<math::FloorOp, spirv::OCLFloorOp>,
234                spirv::ElementwiseOpPattern<math::FmaOp, spirv::OCLFmaOp>,
235                spirv::ElementwiseOpPattern<math::LogOp, spirv::OCLLogOp>,
236                spirv::ElementwiseOpPattern<math::PowFOp, spirv::OCLPowOp>,
237                spirv::ElementwiseOpPattern<math::RsqrtOp, spirv::OCLRsqrtOp>,
238                spirv::ElementwiseOpPattern<math::SinOp, spirv::OCLSinOp>,
239                spirv::ElementwiseOpPattern<math::SqrtOp, spirv::OCLSqrtOp>,
240                spirv::ElementwiseOpPattern<math::TanhOp, spirv::OCLTanhOp>>(
241       typeConverter, patterns.getContext());
242 }
243 
244 } // namespace mlir
245