1a54f4eaeSMogball //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2a54f4eaeSMogball //
3a54f4eaeSMogball // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a54f4eaeSMogball // See https://llvm.org/LICENSE.txt for license information.
5a54f4eaeSMogball // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a54f4eaeSMogball //
7a54f4eaeSMogball //===----------------------------------------------------------------------===//
8a54f4eaeSMogball
9a54f4eaeSMogball #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
10a54f4eaeSMogball #include "../PassDetail.h"
11a54f4eaeSMogball #include "../SPIRVCommon/Pattern.h"
123ba66435SRiver Riddle #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
13a54f4eaeSMogball #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14a54f4eaeSMogball #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15a54f4eaeSMogball #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16a54f4eaeSMogball #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17e1e0ecb9SLei Zhang #include "mlir/IR/BuiltinTypes.h"
18a54f4eaeSMogball #include "llvm/Support/Debug.h"
19a54f4eaeSMogball
20a54f4eaeSMogball #define DEBUG_TYPE "arith-to-spirv-pattern"
21a54f4eaeSMogball
22a54f4eaeSMogball using namespace mlir;
23a54f4eaeSMogball
24a54f4eaeSMogball //===----------------------------------------------------------------------===//
25a54f4eaeSMogball // Operation Conversion
26a54f4eaeSMogball //===----------------------------------------------------------------------===//
27a54f4eaeSMogball
28a54f4eaeSMogball namespace {
29a54f4eaeSMogball
30a54f4eaeSMogball /// Converts composite arith.constant operation to spv.Constant.
31a54f4eaeSMogball struct ConstantCompositeOpPattern final
32a54f4eaeSMogball : public OpConversionPattern<arith::ConstantOp> {
33a54f4eaeSMogball using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
34a54f4eaeSMogball
35a54f4eaeSMogball LogicalResult
36a54f4eaeSMogball matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
37a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
38a54f4eaeSMogball };
39a54f4eaeSMogball
40a54f4eaeSMogball /// Converts scalar arith.constant operation to spv.Constant.
41a54f4eaeSMogball struct ConstantScalarOpPattern final
42a54f4eaeSMogball : public OpConversionPattern<arith::ConstantOp> {
43a54f4eaeSMogball using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
44a54f4eaeSMogball
45a54f4eaeSMogball LogicalResult
46a54f4eaeSMogball matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
47a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
48a54f4eaeSMogball };
49a54f4eaeSMogball
508dae0b6bSButygin /// Converts arith.remsi to GLSL SPIR-V ops.
51a54f4eaeSMogball ///
52a54f4eaeSMogball /// This cannot be merged into the template unary/binary pattern due to Vulkan
53a54f4eaeSMogball /// restrictions over spv.SRem and spv.SMod.
54*52b630daSJakub Kuderski struct RemSIOpGLPattern final : public OpConversionPattern<arith::RemSIOp> {
558dae0b6bSButygin using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
568dae0b6bSButygin
578dae0b6bSButygin LogicalResult
588dae0b6bSButygin matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
598dae0b6bSButygin ConversionPatternRewriter &rewriter) const override;
608dae0b6bSButygin };
618dae0b6bSButygin
628dae0b6bSButygin /// Converts arith.remsi to OpenCL SPIR-V ops.
633930cc68SJakub Kuderski struct RemSIOpCLPattern final : public OpConversionPattern<arith::RemSIOp> {
64a54f4eaeSMogball using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
65a54f4eaeSMogball
66a54f4eaeSMogball LogicalResult
67a54f4eaeSMogball matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
68a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
69a54f4eaeSMogball };
70a54f4eaeSMogball
71a54f4eaeSMogball /// Converts bitwise operations to SPIR-V operations. This is a special pattern
72a54f4eaeSMogball /// other than the BinaryOpPatternPattern because if the operands are boolean
73a54f4eaeSMogball /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
74a54f4eaeSMogball /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
75a54f4eaeSMogball template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
76a54f4eaeSMogball struct BitwiseOpPattern final : public OpConversionPattern<Op> {
77a54f4eaeSMogball using OpConversionPattern<Op>::OpConversionPattern;
78a54f4eaeSMogball
79a54f4eaeSMogball LogicalResult
80a54f4eaeSMogball matchAndRewrite(Op op, typename Op::Adaptor adaptor,
81a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
82a54f4eaeSMogball };
83a54f4eaeSMogball
84a54f4eaeSMogball /// Converts arith.xori to SPIR-V operations.
85a54f4eaeSMogball struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
86a54f4eaeSMogball using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
87a54f4eaeSMogball
88a54f4eaeSMogball LogicalResult
89a54f4eaeSMogball matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
90a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
91a54f4eaeSMogball };
92a54f4eaeSMogball
93a54f4eaeSMogball /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
94a54f4eaeSMogball /// vector of i1.
95a54f4eaeSMogball struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
96a54f4eaeSMogball using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
97a54f4eaeSMogball
98a54f4eaeSMogball LogicalResult
99a54f4eaeSMogball matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
100a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
101a54f4eaeSMogball };
102a54f4eaeSMogball
103a54f4eaeSMogball /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of
104a54f4eaeSMogball /// i1.
105a54f4eaeSMogball struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
106a54f4eaeSMogball using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern;
107a54f4eaeSMogball
108a54f4eaeSMogball LogicalResult
109a54f4eaeSMogball matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
110a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
111a54f4eaeSMogball };
112a54f4eaeSMogball
113a54f4eaeSMogball /// Converts arith.extui to spv.Select if the type of source is i1 or vector of
114a54f4eaeSMogball /// i1.
115a54f4eaeSMogball struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
116a54f4eaeSMogball using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;
117a54f4eaeSMogball
118a54f4eaeSMogball LogicalResult
119a54f4eaeSMogball matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
120a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
121a54f4eaeSMogball };
122a54f4eaeSMogball
123a54f4eaeSMogball /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of
124a54f4eaeSMogball /// i1.
125a54f4eaeSMogball struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
126a54f4eaeSMogball using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;
127a54f4eaeSMogball
128a54f4eaeSMogball LogicalResult
129a54f4eaeSMogball matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
130a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
131a54f4eaeSMogball };
132a54f4eaeSMogball
133a54f4eaeSMogball /// Converts type-casting standard operations to SPIR-V operations.
134a54f4eaeSMogball template <typename Op, typename SPIRVOp>
135a54f4eaeSMogball struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
136a54f4eaeSMogball using OpConversionPattern<Op>::OpConversionPattern;
137a54f4eaeSMogball
138a54f4eaeSMogball LogicalResult
139a54f4eaeSMogball matchAndRewrite(Op op, typename Op::Adaptor adaptor,
140a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
141a54f4eaeSMogball };
142a54f4eaeSMogball
143a54f4eaeSMogball /// Converts integer compare operation on i1 type operands to SPIR-V ops.
144a54f4eaeSMogball class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
145a54f4eaeSMogball public:
146a54f4eaeSMogball using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
147a54f4eaeSMogball
148a54f4eaeSMogball LogicalResult
149a54f4eaeSMogball matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
150a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
151a54f4eaeSMogball };
152a54f4eaeSMogball
153a54f4eaeSMogball /// Converts integer compare operation to SPIR-V ops.
154a54f4eaeSMogball class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
155a54f4eaeSMogball public:
156a54f4eaeSMogball using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
157a54f4eaeSMogball
158a54f4eaeSMogball LogicalResult
159a54f4eaeSMogball matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
160a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
161a54f4eaeSMogball };
162a54f4eaeSMogball
163a54f4eaeSMogball /// Converts floating-point comparison operations to SPIR-V ops.
164a54f4eaeSMogball class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
165a54f4eaeSMogball public:
166a54f4eaeSMogball using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
167a54f4eaeSMogball
168a54f4eaeSMogball LogicalResult
169a54f4eaeSMogball matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
170a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
171a54f4eaeSMogball };
172a54f4eaeSMogball
173a54f4eaeSMogball /// Converts floating point NaN check to SPIR-V ops. This pattern requires
174a54f4eaeSMogball /// Kernel capability.
175a54f4eaeSMogball class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
176a54f4eaeSMogball public:
177a54f4eaeSMogball using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
178a54f4eaeSMogball
179a54f4eaeSMogball LogicalResult
180a54f4eaeSMogball matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
181a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
182a54f4eaeSMogball };
183a54f4eaeSMogball
184a54f4eaeSMogball /// Converts floating point NaN check to SPIR-V ops. This pattern does not
185a54f4eaeSMogball /// require additional capability.
186a54f4eaeSMogball class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
187a54f4eaeSMogball public:
188a54f4eaeSMogball using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
189a54f4eaeSMogball
190a54f4eaeSMogball LogicalResult
191a54f4eaeSMogball matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
192a54f4eaeSMogball ConversionPatternRewriter &rewriter) const override;
193a54f4eaeSMogball };
194a54f4eaeSMogball
195dec8af70SRiver Riddle /// Converts arith.select to spv.Select.
196dec8af70SRiver Riddle class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
197dec8af70SRiver Riddle public:
198dec8af70SRiver Riddle using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
199dec8af70SRiver Riddle LogicalResult
200dec8af70SRiver Riddle matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
201dec8af70SRiver Riddle ConversionPatternRewriter &rewriter) const override;
202dec8af70SRiver Riddle };
203dec8af70SRiver Riddle
204be0a7e9fSMehdi Amini } // namespace
205a54f4eaeSMogball
206a54f4eaeSMogball //===----------------------------------------------------------------------===//
207a54f4eaeSMogball // Conversion Helpers
208a54f4eaeSMogball //===----------------------------------------------------------------------===//
209a54f4eaeSMogball
210a54f4eaeSMogball /// Converts the given `srcAttr` into a boolean attribute if it holds an
211a54f4eaeSMogball /// integral value. Returns null attribute if conversion fails.
convertBoolAttr(Attribute srcAttr,Builder builder)212a54f4eaeSMogball static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
213a54f4eaeSMogball if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
214a54f4eaeSMogball return boolAttr;
215a54f4eaeSMogball if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
216a54f4eaeSMogball return builder.getBoolAttr(intAttr.getValue().getBoolValue());
217a54f4eaeSMogball return BoolAttr();
218a54f4eaeSMogball }
219a54f4eaeSMogball
220a54f4eaeSMogball /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
221a54f4eaeSMogball /// Returns null attribute if conversion fails.
convertIntegerAttr(IntegerAttr srcAttr,IntegerType dstType,Builder builder)222a54f4eaeSMogball static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
223a54f4eaeSMogball Builder builder) {
224a54f4eaeSMogball // If the source number uses less active bits than the target bitwidth, then
225a54f4eaeSMogball // it should be safe to convert.
226a54f4eaeSMogball if (srcAttr.getValue().isIntN(dstType.getWidth()))
227a54f4eaeSMogball return builder.getIntegerAttr(dstType, srcAttr.getInt());
228a54f4eaeSMogball
229a54f4eaeSMogball // XXX: Try again by interpreting the source number as a signed value.
230a54f4eaeSMogball // Although integers in the standard dialect are signless, they can represent
231a54f4eaeSMogball // a signed number. It's the operation decides how to interpret. This is
232a54f4eaeSMogball // dangerous, but it seems there is no good way of handling this if we still
233a54f4eaeSMogball // want to change the bitwidth. Emit a message at least.
234a54f4eaeSMogball if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
235a54f4eaeSMogball auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
236a54f4eaeSMogball LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
237a54f4eaeSMogball << dstAttr << "' for type '" << dstType << "'\n");
238a54f4eaeSMogball return dstAttr;
239a54f4eaeSMogball }
240a54f4eaeSMogball
241a54f4eaeSMogball LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
242a54f4eaeSMogball << "' illegal: cannot fit into target type '"
243a54f4eaeSMogball << dstType << "'\n");
244a54f4eaeSMogball return IntegerAttr();
245a54f4eaeSMogball }
246a54f4eaeSMogball
247a54f4eaeSMogball /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
248a54f4eaeSMogball /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
convertFloatAttr(FloatAttr srcAttr,FloatType dstType,Builder builder)249a54f4eaeSMogball static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
250a54f4eaeSMogball Builder builder) {
251a54f4eaeSMogball // Only support converting to float for now.
252a54f4eaeSMogball if (!dstType.isF32())
253a54f4eaeSMogball return FloatAttr();
254a54f4eaeSMogball
255a54f4eaeSMogball // Try to convert the source floating-point number to single precision.
256a54f4eaeSMogball APFloat dstVal = srcAttr.getValue();
257a54f4eaeSMogball bool losesInfo = false;
258a54f4eaeSMogball APFloat::opStatus status =
259a54f4eaeSMogball dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
260a54f4eaeSMogball if (status != APFloat::opOK || losesInfo) {
261a54f4eaeSMogball LLVM_DEBUG(llvm::dbgs()
262a54f4eaeSMogball << srcAttr << " illegal: cannot fit into converted type '"
263a54f4eaeSMogball << dstType << "'\n");
264a54f4eaeSMogball return FloatAttr();
265a54f4eaeSMogball }
266a54f4eaeSMogball
267a54f4eaeSMogball return builder.getF32FloatAttr(dstVal.convertToFloat());
268a54f4eaeSMogball }
269a54f4eaeSMogball
270a54f4eaeSMogball /// Returns true if the given `type` is a boolean scalar or vector type.
isBoolScalarOrVector(Type type)271a54f4eaeSMogball static bool isBoolScalarOrVector(Type type) {
272a54f4eaeSMogball if (type.isInteger(1))
273a54f4eaeSMogball return true;
274a54f4eaeSMogball if (auto vecType = type.dyn_cast<VectorType>())
275a54f4eaeSMogball return vecType.getElementType().isInteger(1);
276a54f4eaeSMogball return false;
277a54f4eaeSMogball }
278a54f4eaeSMogball
279b5192cbeSLei Zhang /// Returns true if scalar/vector type `a` and `b` have the same number of
280b5192cbeSLei Zhang /// bitwidth.
hasSameBitwidth(Type a,Type b)281b5192cbeSLei Zhang static bool hasSameBitwidth(Type a, Type b) {
282b5192cbeSLei Zhang auto getNumBitwidth = [](Type type) {
283b5192cbeSLei Zhang unsigned bw = 0;
284b5192cbeSLei Zhang if (type.isIntOrFloat())
285b5192cbeSLei Zhang bw = type.getIntOrFloatBitWidth();
286b5192cbeSLei Zhang else if (auto vecType = type.dyn_cast<VectorType>())
287b5192cbeSLei Zhang bw = vecType.getElementTypeBitWidth() * vecType.getNumElements();
288b5192cbeSLei Zhang return bw;
289b5192cbeSLei Zhang };
290b5192cbeSLei Zhang unsigned aBW = getNumBitwidth(a);
291b5192cbeSLei Zhang unsigned bBW = getNumBitwidth(b);
292b5192cbeSLei Zhang return aBW != 0 && bBW != 0 && aBW == bBW;
293b5192cbeSLei Zhang }
294b5192cbeSLei Zhang
295a54f4eaeSMogball //===----------------------------------------------------------------------===//
296a54f4eaeSMogball // ConstantOp with composite type
297a54f4eaeSMogball //===----------------------------------------------------------------------===//
298a54f4eaeSMogball
matchAndRewrite(arith::ConstantOp constOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const299a54f4eaeSMogball LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
300a54f4eaeSMogball arith::ConstantOp constOp, OpAdaptor adaptor,
301a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
302a54f4eaeSMogball auto srcType = constOp.getType().dyn_cast<ShapedType>();
30396130b5dSLei Zhang if (!srcType || srcType.getNumElements() == 1)
304a54f4eaeSMogball return failure();
305a54f4eaeSMogball
306cb3aa49eSMogball // arith.constant should only have vector or tenor types.
307a54f4eaeSMogball assert((srcType.isa<VectorType, RankedTensorType>()));
308a54f4eaeSMogball
309a54f4eaeSMogball auto dstType = getTypeConverter()->convertType(srcType);
310a54f4eaeSMogball if (!dstType)
311a54f4eaeSMogball return failure();
312a54f4eaeSMogball
313cfb72fd3SJacques Pienaar auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>();
314a54f4eaeSMogball if (!dstElementsAttr)
315a54f4eaeSMogball return failure();
316a54f4eaeSMogball
317f1a7e508SLei Zhang ShapedType dstAttrType = dstElementsAttr.getType();
318f1a7e508SLei Zhang
319a54f4eaeSMogball // If the composite type has more than one dimensions, perform linearization.
320a54f4eaeSMogball if (srcType.getRank() > 1) {
321a54f4eaeSMogball if (srcType.isa<RankedTensorType>()) {
322a54f4eaeSMogball dstAttrType = RankedTensorType::get(srcType.getNumElements(),
323a54f4eaeSMogball srcType.getElementType());
324a54f4eaeSMogball dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
325a54f4eaeSMogball } else {
326a54f4eaeSMogball // TODO: add support for large vectors.
327a54f4eaeSMogball return failure();
328a54f4eaeSMogball }
329a54f4eaeSMogball }
330a54f4eaeSMogball
331a54f4eaeSMogball Type srcElemType = srcType.getElementType();
332a54f4eaeSMogball Type dstElemType;
333a54f4eaeSMogball // Tensor types are converted to SPIR-V array types; vector types are
334a54f4eaeSMogball // converted to SPIR-V vector/array types.
335a54f4eaeSMogball if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
336a54f4eaeSMogball dstElemType = arrayType.getElementType();
337a54f4eaeSMogball else
338a54f4eaeSMogball dstElemType = dstType.cast<VectorType>().getElementType();
339a54f4eaeSMogball
340a54f4eaeSMogball // If the source and destination element types are different, perform
341a54f4eaeSMogball // attribute conversion.
342a54f4eaeSMogball if (srcElemType != dstElemType) {
343a54f4eaeSMogball SmallVector<Attribute, 8> elements;
344a54f4eaeSMogball if (srcElemType.isa<FloatType>()) {
345a54f4eaeSMogball for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
346a54f4eaeSMogball FloatAttr dstAttr =
347a54f4eaeSMogball convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
348a54f4eaeSMogball if (!dstAttr)
349a54f4eaeSMogball return failure();
350a54f4eaeSMogball elements.push_back(dstAttr);
351a54f4eaeSMogball }
352a54f4eaeSMogball } else if (srcElemType.isInteger(1)) {
353a54f4eaeSMogball return failure();
354a54f4eaeSMogball } else {
355a54f4eaeSMogball for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
356a54f4eaeSMogball IntegerAttr dstAttr = convertIntegerAttr(
357a54f4eaeSMogball srcAttr, dstElemType.cast<IntegerType>(), rewriter);
358a54f4eaeSMogball if (!dstAttr)
359a54f4eaeSMogball return failure();
360a54f4eaeSMogball elements.push_back(dstAttr);
361a54f4eaeSMogball }
362a54f4eaeSMogball }
363a54f4eaeSMogball
364a54f4eaeSMogball // Unfortunately, we cannot use dialect-specific types for element
365a54f4eaeSMogball // attributes; element attributes only works with builtin types. So we need
366a54f4eaeSMogball // to prepare another converted builtin types for the destination elements
367a54f4eaeSMogball // attribute.
368a54f4eaeSMogball if (dstAttrType.isa<RankedTensorType>())
369a54f4eaeSMogball dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
370a54f4eaeSMogball else
371a54f4eaeSMogball dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
372a54f4eaeSMogball
373a54f4eaeSMogball dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
374a54f4eaeSMogball }
375a54f4eaeSMogball
376a54f4eaeSMogball rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
377a54f4eaeSMogball dstElementsAttr);
378a54f4eaeSMogball return success();
379a54f4eaeSMogball }
380a54f4eaeSMogball
381a54f4eaeSMogball //===----------------------------------------------------------------------===//
382a54f4eaeSMogball // ConstantOp with scalar type
383a54f4eaeSMogball //===----------------------------------------------------------------------===//
384a54f4eaeSMogball
matchAndRewrite(arith::ConstantOp constOp,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const385a54f4eaeSMogball LogicalResult ConstantScalarOpPattern::matchAndRewrite(
386a54f4eaeSMogball arith::ConstantOp constOp, OpAdaptor adaptor,
387a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
388a54f4eaeSMogball Type srcType = constOp.getType();
38996130b5dSLei Zhang if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
39096130b5dSLei Zhang if (shapedType.getNumElements() != 1)
39196130b5dSLei Zhang return failure();
39296130b5dSLei Zhang srcType = shapedType.getElementType();
39396130b5dSLei Zhang }
394a54f4eaeSMogball if (!srcType.isIntOrIndexOrFloat())
395a54f4eaeSMogball return failure();
396a54f4eaeSMogball
39796130b5dSLei Zhang Attribute cstAttr = constOp.getValue();
39896130b5dSLei Zhang if (cstAttr.getType().isa<ShapedType>())
39996130b5dSLei Zhang cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>();
40096130b5dSLei Zhang
401a54f4eaeSMogball Type dstType = getTypeConverter()->convertType(srcType);
402a54f4eaeSMogball if (!dstType)
403a54f4eaeSMogball return failure();
404a54f4eaeSMogball
405a54f4eaeSMogball // Floating-point types.
406a54f4eaeSMogball if (srcType.isa<FloatType>()) {
40796130b5dSLei Zhang auto srcAttr = cstAttr.cast<FloatAttr>();
408a54f4eaeSMogball auto dstAttr = srcAttr;
409a54f4eaeSMogball
410a54f4eaeSMogball // Floating-point types not supported in the target environment are all
411a54f4eaeSMogball // converted to float type.
412a54f4eaeSMogball if (srcType != dstType) {
413a54f4eaeSMogball dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
414a54f4eaeSMogball if (!dstAttr)
415a54f4eaeSMogball return failure();
416a54f4eaeSMogball }
417a54f4eaeSMogball
418a54f4eaeSMogball rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
419a54f4eaeSMogball return success();
420a54f4eaeSMogball }
421a54f4eaeSMogball
422a54f4eaeSMogball // Bool type.
423a54f4eaeSMogball if (srcType.isInteger(1)) {
424cb3aa49eSMogball // arith.constant can use 0/1 instead of true/false for i1 values. We need
425cb3aa49eSMogball // to handle that here.
42696130b5dSLei Zhang auto dstAttr = convertBoolAttr(cstAttr, rewriter);
427a54f4eaeSMogball if (!dstAttr)
428a54f4eaeSMogball return failure();
429a54f4eaeSMogball rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
430a54f4eaeSMogball return success();
431a54f4eaeSMogball }
432a54f4eaeSMogball
433a54f4eaeSMogball // IndexType or IntegerType. Index values are converted to 32-bit integer
434a54f4eaeSMogball // values when converting to SPIR-V.
43596130b5dSLei Zhang auto srcAttr = cstAttr.cast<IntegerAttr>();
436a54f4eaeSMogball auto dstAttr =
437a54f4eaeSMogball convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
438a54f4eaeSMogball if (!dstAttr)
439a54f4eaeSMogball return failure();
440a54f4eaeSMogball rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
441a54f4eaeSMogball return success();
442a54f4eaeSMogball }
443a54f4eaeSMogball
444a54f4eaeSMogball //===----------------------------------------------------------------------===//
445*52b630daSJakub Kuderski // RemSIOpGLPattern
446a54f4eaeSMogball //===----------------------------------------------------------------------===//
447a54f4eaeSMogball
448a54f4eaeSMogball /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
449a54f4eaeSMogball /// the sign of `signOperand`.
450a54f4eaeSMogball ///
451a54f4eaeSMogball /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
452a54f4eaeSMogball /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
453a54f4eaeSMogball /// the result is undefined." So we cannot directly use spv.SRem/spv.SMod
454a54f4eaeSMogball /// if either operand can be negative. Emulate it via spv.UMod.
4558dae0b6bSButygin template <typename SignedAbsOp>
emulateSignedRemainder(Location loc,Value lhs,Value rhs,Value signOperand,OpBuilder & builder)456a54f4eaeSMogball static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
457a54f4eaeSMogball Value signOperand, OpBuilder &builder) {
458a54f4eaeSMogball assert(lhs.getType() == rhs.getType());
459a54f4eaeSMogball assert(lhs == signOperand || rhs == signOperand);
460a54f4eaeSMogball
461a54f4eaeSMogball Type type = lhs.getType();
462a54f4eaeSMogball
463a54f4eaeSMogball // Calculate the remainder with spv.UMod.
4648dae0b6bSButygin Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
4658dae0b6bSButygin Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
466a54f4eaeSMogball Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
467a54f4eaeSMogball
468a54f4eaeSMogball // Fix the sign.
469a54f4eaeSMogball Value isPositive;
470a54f4eaeSMogball if (lhs == signOperand)
471a54f4eaeSMogball isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
472a54f4eaeSMogball else
473a54f4eaeSMogball isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
474a54f4eaeSMogball Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
475a54f4eaeSMogball return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
476a54f4eaeSMogball }
477a54f4eaeSMogball
478a54f4eaeSMogball LogicalResult
matchAndRewrite(arith::RemSIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const479*52b630daSJakub Kuderski RemSIOpGLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
480a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
481*52b630daSJakub Kuderski Value result = emulateSignedRemainder<spirv::GLSAbsOp>(
4828dae0b6bSButygin op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
4838dae0b6bSButygin adaptor.getOperands()[0], rewriter);
4848dae0b6bSButygin rewriter.replaceOp(op, result);
4858dae0b6bSButygin
4868dae0b6bSButygin return success();
4878dae0b6bSButygin }
4888dae0b6bSButygin
4898dae0b6bSButygin //===----------------------------------------------------------------------===//
4903930cc68SJakub Kuderski // RemSIOpCLPattern
4918dae0b6bSButygin //===----------------------------------------------------------------------===//
4928dae0b6bSButygin
4938dae0b6bSButygin LogicalResult
matchAndRewrite(arith::RemSIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const4943930cc68SJakub Kuderski RemSIOpCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
4958dae0b6bSButygin ConversionPatternRewriter &rewriter) const {
4963930cc68SJakub Kuderski Value result = emulateSignedRemainder<spirv::CLSAbsOp>(
4978dae0b6bSButygin op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
498a54f4eaeSMogball adaptor.getOperands()[0], rewriter);
499a54f4eaeSMogball rewriter.replaceOp(op, result);
500a54f4eaeSMogball
501a54f4eaeSMogball return success();
502a54f4eaeSMogball }
503a54f4eaeSMogball
504a54f4eaeSMogball //===----------------------------------------------------------------------===//
505a54f4eaeSMogball // BitwiseOpPattern
506a54f4eaeSMogball //===----------------------------------------------------------------------===//
507a54f4eaeSMogball
508a54f4eaeSMogball template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
509a54f4eaeSMogball LogicalResult
matchAndRewrite(Op op,typename Op::Adaptor adaptor,ConversionPatternRewriter & rewriter) const510a54f4eaeSMogball BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
511a54f4eaeSMogball Op op, typename Op::Adaptor adaptor,
512a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
513a54f4eaeSMogball assert(adaptor.getOperands().size() == 2);
514a54f4eaeSMogball auto dstType =
515a54f4eaeSMogball this->getTypeConverter()->convertType(op.getResult().getType());
516a54f4eaeSMogball if (!dstType)
517a54f4eaeSMogball return failure();
518a54f4eaeSMogball if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
519a54f4eaeSMogball rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
520a54f4eaeSMogball adaptor.getOperands());
521a54f4eaeSMogball } else {
522a54f4eaeSMogball rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
523a54f4eaeSMogball adaptor.getOperands());
524a54f4eaeSMogball }
525a54f4eaeSMogball return success();
526a54f4eaeSMogball }
527a54f4eaeSMogball
528a54f4eaeSMogball //===----------------------------------------------------------------------===//
529a54f4eaeSMogball // XOrIOpLogicalPattern
530a54f4eaeSMogball //===----------------------------------------------------------------------===//
531a54f4eaeSMogball
matchAndRewrite(arith::XOrIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const532a54f4eaeSMogball LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
533a54f4eaeSMogball arith::XOrIOp op, OpAdaptor adaptor,
534a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
535a54f4eaeSMogball assert(adaptor.getOperands().size() == 2);
536a54f4eaeSMogball
537a54f4eaeSMogball if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
538a54f4eaeSMogball return failure();
539a54f4eaeSMogball
540a54f4eaeSMogball auto dstType = getTypeConverter()->convertType(op.getType());
541a54f4eaeSMogball if (!dstType)
542a54f4eaeSMogball return failure();
543a54f4eaeSMogball rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
544a54f4eaeSMogball adaptor.getOperands());
545a54f4eaeSMogball
546a54f4eaeSMogball return success();
547a54f4eaeSMogball }
548a54f4eaeSMogball
549a54f4eaeSMogball //===----------------------------------------------------------------------===//
550a54f4eaeSMogball // XOrIOpBooleanPattern
551a54f4eaeSMogball //===----------------------------------------------------------------------===//
552a54f4eaeSMogball
matchAndRewrite(arith::XOrIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const553a54f4eaeSMogball LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
554a54f4eaeSMogball arith::XOrIOp op, OpAdaptor adaptor,
555a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
556a54f4eaeSMogball assert(adaptor.getOperands().size() == 2);
557a54f4eaeSMogball
558a54f4eaeSMogball if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
559a54f4eaeSMogball return failure();
560a54f4eaeSMogball
561a54f4eaeSMogball auto dstType = getTypeConverter()->convertType(op.getType());
562a54f4eaeSMogball if (!dstType)
563a54f4eaeSMogball return failure();
564a54f4eaeSMogball rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
565a54f4eaeSMogball adaptor.getOperands());
566a54f4eaeSMogball return success();
567a54f4eaeSMogball }
568a54f4eaeSMogball
569a54f4eaeSMogball //===----------------------------------------------------------------------===//
570a54f4eaeSMogball // UIToFPI1Pattern
571a54f4eaeSMogball //===----------------------------------------------------------------------===//
572a54f4eaeSMogball
573a54f4eaeSMogball LogicalResult
matchAndRewrite(arith::UIToFPOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const574a54f4eaeSMogball UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
575a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
576a54f4eaeSMogball auto srcType = adaptor.getOperands().front().getType();
577a54f4eaeSMogball if (!isBoolScalarOrVector(srcType))
578a54f4eaeSMogball return failure();
579a54f4eaeSMogball
580a54f4eaeSMogball auto dstType =
581a54f4eaeSMogball this->getTypeConverter()->convertType(op.getResult().getType());
582a54f4eaeSMogball Location loc = op.getLoc();
583a54f4eaeSMogball Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
584a54f4eaeSMogball Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
585a54f4eaeSMogball rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
586a54f4eaeSMogball op, dstType, adaptor.getOperands().front(), one, zero);
587a54f4eaeSMogball return success();
588a54f4eaeSMogball }
589a54f4eaeSMogball
590a54f4eaeSMogball //===----------------------------------------------------------------------===//
591a54f4eaeSMogball // ExtUII1Pattern
592a54f4eaeSMogball //===----------------------------------------------------------------------===//
593a54f4eaeSMogball
594a54f4eaeSMogball LogicalResult
matchAndRewrite(arith::ExtUIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const595a54f4eaeSMogball ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
596a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
597a54f4eaeSMogball auto srcType = adaptor.getOperands().front().getType();
598a54f4eaeSMogball if (!isBoolScalarOrVector(srcType))
599a54f4eaeSMogball return failure();
600a54f4eaeSMogball
601a54f4eaeSMogball auto dstType =
602a54f4eaeSMogball this->getTypeConverter()->convertType(op.getResult().getType());
603a54f4eaeSMogball Location loc = op.getLoc();
604a54f4eaeSMogball Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
605a54f4eaeSMogball Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
606a54f4eaeSMogball rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
607a54f4eaeSMogball op, dstType, adaptor.getOperands().front(), one, zero);
608a54f4eaeSMogball return success();
609a54f4eaeSMogball }
610a54f4eaeSMogball
611a54f4eaeSMogball //===----------------------------------------------------------------------===//
612a54f4eaeSMogball // TruncII1Pattern
613a54f4eaeSMogball //===----------------------------------------------------------------------===//
614a54f4eaeSMogball
615a54f4eaeSMogball LogicalResult
matchAndRewrite(arith::TruncIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const616a54f4eaeSMogball TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
617a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
618a54f4eaeSMogball auto dstType =
619a54f4eaeSMogball this->getTypeConverter()->convertType(op.getResult().getType());
620a54f4eaeSMogball if (!isBoolScalarOrVector(dstType))
621a54f4eaeSMogball return failure();
622a54f4eaeSMogball
623a54f4eaeSMogball Location loc = op.getLoc();
624a54f4eaeSMogball auto srcType = adaptor.getOperands().front().getType();
625a54f4eaeSMogball // Check if (x & 1) == 1.
626a54f4eaeSMogball Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
627a54f4eaeSMogball Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
628a54f4eaeSMogball loc, srcType, adaptor.getOperands()[0], mask);
629a54f4eaeSMogball Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
630a54f4eaeSMogball
631a54f4eaeSMogball Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
632a54f4eaeSMogball Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
633a54f4eaeSMogball rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
634a54f4eaeSMogball return success();
635a54f4eaeSMogball }
636a54f4eaeSMogball
637a54f4eaeSMogball //===----------------------------------------------------------------------===//
638a54f4eaeSMogball // TypeCastingOpPattern
639a54f4eaeSMogball //===----------------------------------------------------------------------===//
640a54f4eaeSMogball
641a54f4eaeSMogball template <typename Op, typename SPIRVOp>
matchAndRewrite(Op op,typename Op::Adaptor adaptor,ConversionPatternRewriter & rewriter) const642a54f4eaeSMogball LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
643a54f4eaeSMogball Op op, typename Op::Adaptor adaptor,
644a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
645a54f4eaeSMogball assert(adaptor.getOperands().size() == 1);
646a54f4eaeSMogball auto srcType = adaptor.getOperands().front().getType();
647a54f4eaeSMogball auto dstType =
648a54f4eaeSMogball this->getTypeConverter()->convertType(op.getResult().getType());
649a54f4eaeSMogball if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
650a54f4eaeSMogball return failure();
651a54f4eaeSMogball if (dstType == srcType) {
652a54f4eaeSMogball // Due to type conversion, we are seeing the same source and target type.
653a54f4eaeSMogball // Then we can just erase this operation by forwarding its operand.
654a54f4eaeSMogball rewriter.replaceOp(op, adaptor.getOperands().front());
655a54f4eaeSMogball } else {
656a54f4eaeSMogball rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
657a54f4eaeSMogball adaptor.getOperands());
658a54f4eaeSMogball }
659a54f4eaeSMogball return success();
660a54f4eaeSMogball }
661a54f4eaeSMogball
662a54f4eaeSMogball //===----------------------------------------------------------------------===//
663a54f4eaeSMogball // CmpIOpBooleanPattern
664a54f4eaeSMogball //===----------------------------------------------------------------------===//
665a54f4eaeSMogball
matchAndRewrite(arith::CmpIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const666a54f4eaeSMogball LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
667a54f4eaeSMogball arith::CmpIOp op, OpAdaptor adaptor,
668a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
669e1e0ecb9SLei Zhang Type srcType = op.getLhs().getType();
670e1e0ecb9SLei Zhang if (!isBoolScalarOrVector(srcType))
671e1e0ecb9SLei Zhang return failure();
672e1e0ecb9SLei Zhang Type dstType = getTypeConverter()->convertType(srcType);
673e1e0ecb9SLei Zhang if (!dstType)
674a54f4eaeSMogball return failure();
675a54f4eaeSMogball
676a54f4eaeSMogball switch (op.getPredicate()) {
677e1e0ecb9SLei Zhang case arith::CmpIPredicate::eq: {
678e1e0ecb9SLei Zhang rewriter.replaceOpWithNewOp<spirv::LogicalEqualOp>(op, adaptor.getLhs(),
679e1e0ecb9SLei Zhang adaptor.getRhs());
680e1e0ecb9SLei Zhang return success();
681b5192cbeSLei Zhang }
682e1e0ecb9SLei Zhang case arith::CmpIPredicate::ne: {
683e1e0ecb9SLei Zhang rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, adaptor.getLhs(),
684e1e0ecb9SLei Zhang adaptor.getRhs());
685e1e0ecb9SLei Zhang return success();
686e1e0ecb9SLei Zhang }
687e1e0ecb9SLei Zhang case arith::CmpIPredicate::uge:
688e1e0ecb9SLei Zhang case arith::CmpIPredicate::ugt:
689e1e0ecb9SLei Zhang case arith::CmpIPredicate::ule:
690e1e0ecb9SLei Zhang case arith::CmpIPredicate::ult: {
691e1e0ecb9SLei Zhang // There are no direct corresponding instructions in SPIR-V for such cases.
692e1e0ecb9SLei Zhang // Extend them to 32-bit and do comparision then.
693e1e0ecb9SLei Zhang Type type = rewriter.getI32Type();
694e1e0ecb9SLei Zhang if (auto vectorType = dstType.dyn_cast<VectorType>())
695e1e0ecb9SLei Zhang type = VectorType::get(vectorType.getShape(), type);
696e1e0ecb9SLei Zhang auto extLhs =
697e1e0ecb9SLei Zhang rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getLhs());
698e1e0ecb9SLei Zhang auto extRhs =
699e1e0ecb9SLei Zhang rewriter.create<arith::ExtUIOp>(op.getLoc(), type, adaptor.getRhs());
700a54f4eaeSMogball
701e1e0ecb9SLei Zhang rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), extLhs,
702e1e0ecb9SLei Zhang extRhs);
703e1e0ecb9SLei Zhang return success();
704e1e0ecb9SLei Zhang }
705e1e0ecb9SLei Zhang default:
706e1e0ecb9SLei Zhang break;
707a54f4eaeSMogball }
708a54f4eaeSMogball return failure();
709a54f4eaeSMogball }
710a54f4eaeSMogball
711a54f4eaeSMogball //===----------------------------------------------------------------------===//
712a54f4eaeSMogball // CmpIOpPattern
713a54f4eaeSMogball //===----------------------------------------------------------------------===//
714a54f4eaeSMogball
715a54f4eaeSMogball LogicalResult
matchAndRewrite(arith::CmpIOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const716a54f4eaeSMogball CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
717a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
718b5192cbeSLei Zhang Type srcType = op.getLhs().getType();
719b5192cbeSLei Zhang if (isBoolScalarOrVector(srcType))
720b5192cbeSLei Zhang return failure();
721b5192cbeSLei Zhang Type dstType = getTypeConverter()->convertType(srcType);
722b5192cbeSLei Zhang if (!dstType)
723a54f4eaeSMogball return failure();
724a54f4eaeSMogball
725a54f4eaeSMogball switch (op.getPredicate()) {
726a54f4eaeSMogball #define DISPATCH(cmpPredicate, spirvOp) \
727a54f4eaeSMogball case cmpPredicate: \
728a54f4eaeSMogball if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() && \
729b5192cbeSLei Zhang srcType != dstType && !hasSameBitwidth(srcType, dstType)) { \
730a54f4eaeSMogball return op.emitError( \
731a54f4eaeSMogball "bitwidth emulation is not implemented yet on unsigned op"); \
732a54f4eaeSMogball } \
733b5192cbeSLei Zhang rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
734b5192cbeSLei Zhang adaptor.getRhs()); \
735a54f4eaeSMogball return success();
736a54f4eaeSMogball
737a54f4eaeSMogball DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
738a54f4eaeSMogball DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
739a54f4eaeSMogball DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
740a54f4eaeSMogball DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
741a54f4eaeSMogball DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
742a54f4eaeSMogball DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
743a54f4eaeSMogball DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
744a54f4eaeSMogball DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
745a54f4eaeSMogball DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
746a54f4eaeSMogball DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
747a54f4eaeSMogball
748a54f4eaeSMogball #undef DISPATCH
749a54f4eaeSMogball }
750a54f4eaeSMogball return failure();
751a54f4eaeSMogball }
752a54f4eaeSMogball
753a54f4eaeSMogball //===----------------------------------------------------------------------===//
754a54f4eaeSMogball // CmpFOpPattern
755a54f4eaeSMogball //===----------------------------------------------------------------------===//
756a54f4eaeSMogball
757a54f4eaeSMogball LogicalResult
matchAndRewrite(arith::CmpFOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const758a54f4eaeSMogball CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
759a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
760a54f4eaeSMogball switch (op.getPredicate()) {
761a54f4eaeSMogball #define DISPATCH(cmpPredicate, spirvOp) \
762a54f4eaeSMogball case cmpPredicate: \
763b5192cbeSLei Zhang rewriter.replaceOpWithNewOp<spirvOp>(op, adaptor.getLhs(), \
764b5192cbeSLei Zhang adaptor.getRhs()); \
765a54f4eaeSMogball return success();
766a54f4eaeSMogball
767a54f4eaeSMogball // Ordered.
768a54f4eaeSMogball DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
769a54f4eaeSMogball DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
770a54f4eaeSMogball DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
771a54f4eaeSMogball DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
772a54f4eaeSMogball DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
773a54f4eaeSMogball DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
774a54f4eaeSMogball // Unordered.
775a54f4eaeSMogball DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
776a54f4eaeSMogball DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
777a54f4eaeSMogball DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
778a54f4eaeSMogball DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
779a54f4eaeSMogball DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
780a54f4eaeSMogball DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
781a54f4eaeSMogball
782a54f4eaeSMogball #undef DISPATCH
783a54f4eaeSMogball
784a54f4eaeSMogball default:
785a54f4eaeSMogball break;
786a54f4eaeSMogball }
787a54f4eaeSMogball return failure();
788a54f4eaeSMogball }
789a54f4eaeSMogball
790a54f4eaeSMogball //===----------------------------------------------------------------------===//
791a54f4eaeSMogball // CmpFOpNanKernelPattern
792a54f4eaeSMogball //===----------------------------------------------------------------------===//
793a54f4eaeSMogball
matchAndRewrite(arith::CmpFOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const794a54f4eaeSMogball LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
795a54f4eaeSMogball arith::CmpFOp op, OpAdaptor adaptor,
796a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
797a54f4eaeSMogball if (op.getPredicate() == arith::CmpFPredicate::ORD) {
798cfb72fd3SJacques Pienaar rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
799cfb72fd3SJacques Pienaar adaptor.getRhs());
800a54f4eaeSMogball return success();
801a54f4eaeSMogball }
802a54f4eaeSMogball
803a54f4eaeSMogball if (op.getPredicate() == arith::CmpFPredicate::UNO) {
804cfb72fd3SJacques Pienaar rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
805cfb72fd3SJacques Pienaar adaptor.getRhs());
806a54f4eaeSMogball return success();
807a54f4eaeSMogball }
808a54f4eaeSMogball
809a54f4eaeSMogball return failure();
810a54f4eaeSMogball }
811a54f4eaeSMogball
812a54f4eaeSMogball //===----------------------------------------------------------------------===//
813a54f4eaeSMogball // CmpFOpNanNonePattern
814a54f4eaeSMogball //===----------------------------------------------------------------------===//
815a54f4eaeSMogball
matchAndRewrite(arith::CmpFOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const816a54f4eaeSMogball LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
817a54f4eaeSMogball arith::CmpFOp op, OpAdaptor adaptor,
818a54f4eaeSMogball ConversionPatternRewriter &rewriter) const {
819a54f4eaeSMogball if (op.getPredicate() != arith::CmpFPredicate::ORD &&
820a54f4eaeSMogball op.getPredicate() != arith::CmpFPredicate::UNO)
821a54f4eaeSMogball return failure();
822a54f4eaeSMogball
823a54f4eaeSMogball Location loc = op.getLoc();
824a54f4eaeSMogball
825cfb72fd3SJacques Pienaar Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
826cfb72fd3SJacques Pienaar Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
827a54f4eaeSMogball
828a54f4eaeSMogball Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
829a54f4eaeSMogball if (op.getPredicate() == arith::CmpFPredicate::ORD)
830a54f4eaeSMogball replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
831a54f4eaeSMogball
832a54f4eaeSMogball rewriter.replaceOp(op, replace);
833a54f4eaeSMogball return success();
834a54f4eaeSMogball }
835a54f4eaeSMogball
836a54f4eaeSMogball //===----------------------------------------------------------------------===//
837dec8af70SRiver Riddle // SelectOpPattern
838dec8af70SRiver Riddle //===----------------------------------------------------------------------===//
839dec8af70SRiver Riddle
840dec8af70SRiver Riddle LogicalResult
matchAndRewrite(arith::SelectOp op,OpAdaptor adaptor,ConversionPatternRewriter & rewriter) const841dec8af70SRiver Riddle SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
842dec8af70SRiver Riddle ConversionPatternRewriter &rewriter) const {
843dec8af70SRiver Riddle rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
844dec8af70SRiver Riddle adaptor.getTrueValue(),
845dec8af70SRiver Riddle adaptor.getFalseValue());
846dec8af70SRiver Riddle return success();
847dec8af70SRiver Riddle }
848dec8af70SRiver Riddle
849dec8af70SRiver Riddle //===----------------------------------------------------------------------===//
850a54f4eaeSMogball // Pattern Population
851a54f4eaeSMogball //===----------------------------------------------------------------------===//
852a54f4eaeSMogball
populateArithmeticToSPIRVPatterns(SPIRVTypeConverter & typeConverter,RewritePatternSet & patterns)853a54f4eaeSMogball void mlir::arith::populateArithmeticToSPIRVPatterns(
854a54f4eaeSMogball SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
855a54f4eaeSMogball // clang-format off
856a54f4eaeSMogball patterns.add<
857a54f4eaeSMogball ConstantCompositeOpPattern,
858a54f4eaeSMogball ConstantScalarOpPattern,
859d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
860d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
861d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
862d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
863d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
864d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
865*52b630daSJakub Kuderski RemSIOpGLPattern, RemSIOpCLPattern,
866a54f4eaeSMogball BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
867a54f4eaeSMogball BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
868a54f4eaeSMogball XOrIOpLogicalPattern, XOrIOpBooleanPattern,
869d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
870d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
871d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
872d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
873d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
874d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
875d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
876d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
877d9edc1a5SThomas Raoux spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
878a54f4eaeSMogball TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
879a54f4eaeSMogball TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
880a54f4eaeSMogball TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
881a54f4eaeSMogball TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
882a54f4eaeSMogball TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
883a54f4eaeSMogball TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
884a54f4eaeSMogball TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
885a54f4eaeSMogball TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
886a54f4eaeSMogball TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
8876e2c0e69Sxndcn TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
888a54f4eaeSMogball CmpIOpBooleanPattern, CmpIOpPattern,
889dec8af70SRiver Riddle CmpFOpNanNonePattern, CmpFOpPattern,
8903ba66435SRiver Riddle SelectOpPattern,
8913ba66435SRiver Riddle
892*52b630daSJakub Kuderski spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLFMaxOp>,
893*52b630daSJakub Kuderski spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSMaxOp>,
894*52b630daSJakub Kuderski spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLUMaxOp>,
895*52b630daSJakub Kuderski spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLFMinOp>,
896*52b630daSJakub Kuderski spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSMinOp>,
897*52b630daSJakub Kuderski spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLUMinOp>
898a54f4eaeSMogball >(typeConverter, patterns.getContext());
899a54f4eaeSMogball // clang-format on
900a54f4eaeSMogball
901a54f4eaeSMogball // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
902a54f4eaeSMogball // capability is available.
903a54f4eaeSMogball patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
904a54f4eaeSMogball /*benefit=*/2);
905a54f4eaeSMogball }
906a54f4eaeSMogball
907a54f4eaeSMogball //===----------------------------------------------------------------------===//
908a54f4eaeSMogball // Pass Definition
909a54f4eaeSMogball //===----------------------------------------------------------------------===//
910a54f4eaeSMogball
911a54f4eaeSMogball namespace {
912a54f4eaeSMogball struct ConvertArithmeticToSPIRVPass
913a54f4eaeSMogball : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
runOnOperation__anon52142abf0311::ConvertArithmeticToSPIRVPass91441574554SRiver Riddle void runOnOperation() override {
9153ba66435SRiver Riddle auto module = getOperation();
916a54f4eaeSMogball auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
917a54f4eaeSMogball auto target = SPIRVConversionTarget::get(targetAttr);
918a54f4eaeSMogball
919a54f4eaeSMogball SPIRVTypeConverter::Options options;
920a54f4eaeSMogball options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
921a54f4eaeSMogball SPIRVTypeConverter typeConverter(targetAttr, options);
922a54f4eaeSMogball
92391de20c3SLei Zhang // Use UnrealizedConversionCast as the bridge so that we don't need to pull
92491de20c3SLei Zhang // in patterns for other dialects.
92591de20c3SLei Zhang auto addUnrealizedCast = [](OpBuilder &builder, Type type,
92691de20c3SLei Zhang ValueRange inputs, Location loc) {
92791de20c3SLei Zhang auto cast = builder.create<UnrealizedConversionCastOp>(loc, type, inputs);
92891de20c3SLei Zhang return Optional<Value>(cast.getResult(0));
92991de20c3SLei Zhang };
93091de20c3SLei Zhang typeConverter.addSourceMaterialization(addUnrealizedCast);
93191de20c3SLei Zhang typeConverter.addTargetMaterialization(addUnrealizedCast);
93291de20c3SLei Zhang target->addLegalOp<UnrealizedConversionCastOp>();
93391de20c3SLei Zhang
934a54f4eaeSMogball RewritePatternSet patterns(&getContext());
9353ba66435SRiver Riddle arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
936a54f4eaeSMogball
9373ba66435SRiver Riddle if (failed(applyPartialConversion(module, *target, std::move(patterns))))
938a54f4eaeSMogball signalPassFailure();
939a54f4eaeSMogball }
940a54f4eaeSMogball };
941be0a7e9fSMehdi Amini } // namespace
942a54f4eaeSMogball
createConvertArithmeticToSPIRVPass()943a54f4eaeSMogball std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
944a54f4eaeSMogball return std::make_unique<ConvertArithmeticToSPIRVPass>();
945a54f4eaeSMogball }
946