1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
10 #include "../PassDetail.h"
11 #include "../SPIRVCommon/Pattern.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
16 #include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "arith-to-spirv-pattern"
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 // Operation Conversion
24 //===----------------------------------------------------------------------===//
25 
26 namespace {
27 
28 /// Converts composite arith.constant operation to spv.Constant.
29 struct ConstantCompositeOpPattern final
30     : public OpConversionPattern<arith::ConstantOp> {
31   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
32 
33   LogicalResult
34   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
35                   ConversionPatternRewriter &rewriter) const override;
36 };
37 
38 /// Converts scalar arith.constant operation to spv.Constant.
39 struct ConstantScalarOpPattern final
40     : public OpConversionPattern<arith::ConstantOp> {
41   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
42 
43   LogicalResult
44   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
45                   ConversionPatternRewriter &rewriter) const override;
46 };
47 
48 /// Converts arith.remsi to GLSL SPIR-V ops.
49 ///
50 /// This cannot be merged into the template unary/binary pattern due to Vulkan
51 /// restrictions over spv.SRem and spv.SMod.
52 struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> {
53   using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
54 
55   LogicalResult
56   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
57                   ConversionPatternRewriter &rewriter) const override;
58 };
59 
60 /// Converts arith.remsi to OpenCL SPIR-V ops.
61 struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> {
62   using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
63 
64   LogicalResult
65   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
66                   ConversionPatternRewriter &rewriter) const override;
67 };
68 
69 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
70 /// other than the BinaryOpPatternPattern because if the operands are boolean
71 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
72 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
73 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
74 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
75   using OpConversionPattern<Op>::OpConversionPattern;
76 
77   LogicalResult
78   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
79                   ConversionPatternRewriter &rewriter) const override;
80 };
81 
82 /// Converts arith.xori to SPIR-V operations.
83 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
84   using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
85 
86   LogicalResult
87   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
88                   ConversionPatternRewriter &rewriter) const override;
89 };
90 
91 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
92 /// vector of i1.
93 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
94   using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
95 
96   LogicalResult
97   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
98                   ConversionPatternRewriter &rewriter) const override;
99 };
100 
101 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of
102 /// i1.
103 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
104   using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern;
105 
106   LogicalResult
107   matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
108                   ConversionPatternRewriter &rewriter) const override;
109 };
110 
111 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of
112 /// i1.
113 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
114   using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;
115 
116   LogicalResult
117   matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
118                   ConversionPatternRewriter &rewriter) const override;
119 };
120 
121 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of
122 /// i1.
123 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
124   using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;
125 
126   LogicalResult
127   matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
128                   ConversionPatternRewriter &rewriter) const override;
129 };
130 
131 /// Converts type-casting standard operations to SPIR-V operations.
132 template <typename Op, typename SPIRVOp>
133 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
134   using OpConversionPattern<Op>::OpConversionPattern;
135 
136   LogicalResult
137   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
138                   ConversionPatternRewriter &rewriter) const override;
139 };
140 
141 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
142 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
143 public:
144   using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
145 
146   LogicalResult
147   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
148                   ConversionPatternRewriter &rewriter) const override;
149 };
150 
151 /// Converts integer compare operation to SPIR-V ops.
152 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
153 public:
154   using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
155 
156   LogicalResult
157   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
158                   ConversionPatternRewriter &rewriter) const override;
159 };
160 
161 /// Converts floating-point comparison operations to SPIR-V ops.
162 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
163 public:
164   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
165 
166   LogicalResult
167   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
168                   ConversionPatternRewriter &rewriter) const override;
169 };
170 
171 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
172 /// Kernel capability.
173 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
174 public:
175   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
176 
177   LogicalResult
178   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
179                   ConversionPatternRewriter &rewriter) const override;
180 };
181 
182 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
183 /// require additional capability.
184 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
185 public:
186   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
187 
188   LogicalResult
189   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
190                   ConversionPatternRewriter &rewriter) const override;
191 };
192 
193 } // namespace
194 
195 //===----------------------------------------------------------------------===//
196 // Conversion Helpers
197 //===----------------------------------------------------------------------===//
198 
199 /// Converts the given `srcAttr` into a boolean attribute if it holds an
200 /// integral value. Returns null attribute if conversion fails.
201 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
202   if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
203     return boolAttr;
204   if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
205     return builder.getBoolAttr(intAttr.getValue().getBoolValue());
206   return BoolAttr();
207 }
208 
209 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
210 /// Returns null attribute if conversion fails.
211 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
212                                       Builder builder) {
213   // If the source number uses less active bits than the target bitwidth, then
214   // it should be safe to convert.
215   if (srcAttr.getValue().isIntN(dstType.getWidth()))
216     return builder.getIntegerAttr(dstType, srcAttr.getInt());
217 
218   // XXX: Try again by interpreting the source number as a signed value.
219   // Although integers in the standard dialect are signless, they can represent
220   // a signed number. It's the operation decides how to interpret. This is
221   // dangerous, but it seems there is no good way of handling this if we still
222   // want to change the bitwidth. Emit a message at least.
223   if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
224     auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
225     LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
226                             << dstAttr << "' for type '" << dstType << "'\n");
227     return dstAttr;
228   }
229 
230   LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
231                           << "' illegal: cannot fit into target type '"
232                           << dstType << "'\n");
233   return IntegerAttr();
234 }
235 
236 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
237 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
238 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
239                                   Builder builder) {
240   // Only support converting to float for now.
241   if (!dstType.isF32())
242     return FloatAttr();
243 
244   // Try to convert the source floating-point number to single precision.
245   APFloat dstVal = srcAttr.getValue();
246   bool losesInfo = false;
247   APFloat::opStatus status =
248       dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
249   if (status != APFloat::opOK || losesInfo) {
250     LLVM_DEBUG(llvm::dbgs()
251                << srcAttr << " illegal: cannot fit into converted type '"
252                << dstType << "'\n");
253     return FloatAttr();
254   }
255 
256   return builder.getF32FloatAttr(dstVal.convertToFloat());
257 }
258 
259 /// Returns true if the given `type` is a boolean scalar or vector type.
260 static bool isBoolScalarOrVector(Type type) {
261   if (type.isInteger(1))
262     return true;
263   if (auto vecType = type.dyn_cast<VectorType>())
264     return vecType.getElementType().isInteger(1);
265   return false;
266 }
267 
268 //===----------------------------------------------------------------------===//
269 // ConstantOp with composite type
270 //===----------------------------------------------------------------------===//
271 
272 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
273     arith::ConstantOp constOp, OpAdaptor adaptor,
274     ConversionPatternRewriter &rewriter) const {
275   auto srcType = constOp.getType().dyn_cast<ShapedType>();
276   if (!srcType || srcType.getNumElements() == 1)
277     return failure();
278 
279   // arith.constant should only have vector or tenor types.
280   assert((srcType.isa<VectorType, RankedTensorType>()));
281 
282   auto dstType = getTypeConverter()->convertType(srcType);
283   if (!dstType)
284     return failure();
285 
286   auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>();
287   ShapedType dstAttrType = dstElementsAttr.getType();
288   if (!dstElementsAttr)
289     return failure();
290 
291   // If the composite type has more than one dimensions, perform linearization.
292   if (srcType.getRank() > 1) {
293     if (srcType.isa<RankedTensorType>()) {
294       dstAttrType = RankedTensorType::get(srcType.getNumElements(),
295                                           srcType.getElementType());
296       dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
297     } else {
298       // TODO: add support for large vectors.
299       return failure();
300     }
301   }
302 
303   Type srcElemType = srcType.getElementType();
304   Type dstElemType;
305   // Tensor types are converted to SPIR-V array types; vector types are
306   // converted to SPIR-V vector/array types.
307   if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
308     dstElemType = arrayType.getElementType();
309   else
310     dstElemType = dstType.cast<VectorType>().getElementType();
311 
312   // If the source and destination element types are different, perform
313   // attribute conversion.
314   if (srcElemType != dstElemType) {
315     SmallVector<Attribute, 8> elements;
316     if (srcElemType.isa<FloatType>()) {
317       for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
318         FloatAttr dstAttr =
319             convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
320         if (!dstAttr)
321           return failure();
322         elements.push_back(dstAttr);
323       }
324     } else if (srcElemType.isInteger(1)) {
325       return failure();
326     } else {
327       for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
328         IntegerAttr dstAttr = convertIntegerAttr(
329             srcAttr, dstElemType.cast<IntegerType>(), rewriter);
330         if (!dstAttr)
331           return failure();
332         elements.push_back(dstAttr);
333       }
334     }
335 
336     // Unfortunately, we cannot use dialect-specific types for element
337     // attributes; element attributes only works with builtin types. So we need
338     // to prepare another converted builtin types for the destination elements
339     // attribute.
340     if (dstAttrType.isa<RankedTensorType>())
341       dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
342     else
343       dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
344 
345     dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
346   }
347 
348   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
349                                                  dstElementsAttr);
350   return success();
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // ConstantOp with scalar type
355 //===----------------------------------------------------------------------===//
356 
357 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
358     arith::ConstantOp constOp, OpAdaptor adaptor,
359     ConversionPatternRewriter &rewriter) const {
360   Type srcType = constOp.getType();
361   if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
362     if (shapedType.getNumElements() != 1)
363       return failure();
364     srcType = shapedType.getElementType();
365   }
366   if (!srcType.isIntOrIndexOrFloat())
367     return failure();
368 
369   Attribute cstAttr = constOp.getValue();
370   if (cstAttr.getType().isa<ShapedType>())
371     cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>();
372 
373   Type dstType = getTypeConverter()->convertType(srcType);
374   if (!dstType)
375     return failure();
376 
377   // Floating-point types.
378   if (srcType.isa<FloatType>()) {
379     auto srcAttr = cstAttr.cast<FloatAttr>();
380     auto dstAttr = srcAttr;
381 
382     // Floating-point types not supported in the target environment are all
383     // converted to float type.
384     if (srcType != dstType) {
385       dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
386       if (!dstAttr)
387         return failure();
388     }
389 
390     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
391     return success();
392   }
393 
394   // Bool type.
395   if (srcType.isInteger(1)) {
396     // arith.constant can use 0/1 instead of true/false for i1 values. We need
397     // to handle that here.
398     auto dstAttr = convertBoolAttr(cstAttr, rewriter);
399     if (!dstAttr)
400       return failure();
401     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
402     return success();
403   }
404 
405   // IndexType or IntegerType. Index values are converted to 32-bit integer
406   // values when converting to SPIR-V.
407   auto srcAttr = cstAttr.cast<IntegerAttr>();
408   auto dstAttr =
409       convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
410   if (!dstAttr)
411     return failure();
412   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
413   return success();
414 }
415 
416 //===----------------------------------------------------------------------===//
417 // RemSIOpGLSLPattern
418 //===----------------------------------------------------------------------===//
419 
420 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
421 /// the sign of `signOperand`.
422 ///
423 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
424 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
425 /// the result is undefined."  So we cannot directly use spv.SRem/spv.SMod
426 /// if either operand can be negative. Emulate it via spv.UMod.
427 template <typename SignedAbsOp>
428 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
429                                     Value signOperand, OpBuilder &builder) {
430   assert(lhs.getType() == rhs.getType());
431   assert(lhs == signOperand || rhs == signOperand);
432 
433   Type type = lhs.getType();
434 
435   // Calculate the remainder with spv.UMod.
436   Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
437   Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
438   Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
439 
440   // Fix the sign.
441   Value isPositive;
442   if (lhs == signOperand)
443     isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
444   else
445     isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
446   Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
447   return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
448 }
449 
450 LogicalResult
451 RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
452                                     ConversionPatternRewriter &rewriter) const {
453   Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>(
454       op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
455       adaptor.getOperands()[0], rewriter);
456   rewriter.replaceOp(op, result);
457 
458   return success();
459 }
460 
461 //===----------------------------------------------------------------------===//
462 // RemSIOpOCLPattern
463 //===----------------------------------------------------------------------===//
464 
465 LogicalResult
466 RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
467                                    ConversionPatternRewriter &rewriter) const {
468   Value result = emulateSignedRemainder<spirv::OCLSAbsOp>(
469       op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
470       adaptor.getOperands()[0], rewriter);
471   rewriter.replaceOp(op, result);
472 
473   return success();
474 }
475 
476 //===----------------------------------------------------------------------===//
477 // BitwiseOpPattern
478 //===----------------------------------------------------------------------===//
479 
480 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
481 LogicalResult
482 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
483     Op op, typename Op::Adaptor adaptor,
484     ConversionPatternRewriter &rewriter) const {
485   assert(adaptor.getOperands().size() == 2);
486   auto dstType =
487       this->getTypeConverter()->convertType(op.getResult().getType());
488   if (!dstType)
489     return failure();
490   if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
491     rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
492                                                          adaptor.getOperands());
493   } else {
494     rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
495                                                          adaptor.getOperands());
496   }
497   return success();
498 }
499 
500 //===----------------------------------------------------------------------===//
501 // XOrIOpLogicalPattern
502 //===----------------------------------------------------------------------===//
503 
504 LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
505     arith::XOrIOp op, OpAdaptor adaptor,
506     ConversionPatternRewriter &rewriter) const {
507   assert(adaptor.getOperands().size() == 2);
508 
509   if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
510     return failure();
511 
512   auto dstType = getTypeConverter()->convertType(op.getType());
513   if (!dstType)
514     return failure();
515   rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
516                                                    adaptor.getOperands());
517 
518   return success();
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // XOrIOpBooleanPattern
523 //===----------------------------------------------------------------------===//
524 
525 LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
526     arith::XOrIOp op, OpAdaptor adaptor,
527     ConversionPatternRewriter &rewriter) const {
528   assert(adaptor.getOperands().size() == 2);
529 
530   if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
531     return failure();
532 
533   auto dstType = getTypeConverter()->convertType(op.getType());
534   if (!dstType)
535     return failure();
536   rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
537                                                         adaptor.getOperands());
538   return success();
539 }
540 
541 //===----------------------------------------------------------------------===//
542 // UIToFPI1Pattern
543 //===----------------------------------------------------------------------===//
544 
545 LogicalResult
546 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
547                                  ConversionPatternRewriter &rewriter) const {
548   auto srcType = adaptor.getOperands().front().getType();
549   if (!isBoolScalarOrVector(srcType))
550     return failure();
551 
552   auto dstType =
553       this->getTypeConverter()->convertType(op.getResult().getType());
554   Location loc = op.getLoc();
555   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
556   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
557   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
558       op, dstType, adaptor.getOperands().front(), one, zero);
559   return success();
560 }
561 
562 //===----------------------------------------------------------------------===//
563 // ExtUII1Pattern
564 //===----------------------------------------------------------------------===//
565 
566 LogicalResult
567 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
568                                 ConversionPatternRewriter &rewriter) const {
569   auto srcType = adaptor.getOperands().front().getType();
570   if (!isBoolScalarOrVector(srcType))
571     return failure();
572 
573   auto dstType =
574       this->getTypeConverter()->convertType(op.getResult().getType());
575   Location loc = op.getLoc();
576   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
577   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
578   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
579       op, dstType, adaptor.getOperands().front(), one, zero);
580   return success();
581 }
582 
583 //===----------------------------------------------------------------------===//
584 // TruncII1Pattern
585 //===----------------------------------------------------------------------===//
586 
587 LogicalResult
588 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
589                                  ConversionPatternRewriter &rewriter) const {
590   auto dstType =
591       this->getTypeConverter()->convertType(op.getResult().getType());
592   if (!isBoolScalarOrVector(dstType))
593     return failure();
594 
595   Location loc = op.getLoc();
596   auto srcType = adaptor.getOperands().front().getType();
597   // Check if (x & 1) == 1.
598   Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
599   Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
600       loc, srcType, adaptor.getOperands()[0], mask);
601   Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
602 
603   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
604   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
605   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
606   return success();
607 }
608 
609 //===----------------------------------------------------------------------===//
610 // TypeCastingOpPattern
611 //===----------------------------------------------------------------------===//
612 
613 template <typename Op, typename SPIRVOp>
614 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
615     Op op, typename Op::Adaptor adaptor,
616     ConversionPatternRewriter &rewriter) const {
617   assert(adaptor.getOperands().size() == 1);
618   auto srcType = adaptor.getOperands().front().getType();
619   auto dstType =
620       this->getTypeConverter()->convertType(op.getResult().getType());
621   if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
622     return failure();
623   if (dstType == srcType) {
624     // Due to type conversion, we are seeing the same source and target type.
625     // Then we can just erase this operation by forwarding its operand.
626     rewriter.replaceOp(op, adaptor.getOperands().front());
627   } else {
628     rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
629                                                   adaptor.getOperands());
630   }
631   return success();
632 }
633 
634 //===----------------------------------------------------------------------===//
635 // CmpIOpBooleanPattern
636 //===----------------------------------------------------------------------===//
637 
638 LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
639     arith::CmpIOp op, OpAdaptor adaptor,
640     ConversionPatternRewriter &rewriter) const {
641   Type operandType = op.getLhs().getType();
642   if (!isBoolScalarOrVector(operandType))
643     return failure();
644 
645   switch (op.getPredicate()) {
646 #define DISPATCH(cmpPredicate, spirvOp)                                        \
647   case cmpPredicate:                                                           \
648     rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
649                                          adaptor.getLhs(), adaptor.getRhs());  \
650     return success();
651 
652     DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp);
653     DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp);
654 
655 #undef DISPATCH
656   default:;
657   }
658   return failure();
659 }
660 
661 //===----------------------------------------------------------------------===//
662 // CmpIOpPattern
663 //===----------------------------------------------------------------------===//
664 
665 LogicalResult
666 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
667                                ConversionPatternRewriter &rewriter) const {
668   Type operandType = op.getLhs().getType();
669   if (isBoolScalarOrVector(operandType))
670     return failure();
671 
672   switch (op.getPredicate()) {
673 #define DISPATCH(cmpPredicate, spirvOp)                                        \
674   case cmpPredicate:                                                           \
675     if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
676         operandType != this->getTypeConverter()->convertType(operandType)) {   \
677       return op.emitError(                                                     \
678           "bitwidth emulation is not implemented yet on unsigned op");         \
679     }                                                                          \
680     rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
681                                          adaptor.getLhs(), adaptor.getRhs());  \
682     return success();
683 
684     DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
685     DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
686     DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
687     DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
688     DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
689     DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
690     DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
691     DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
692     DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
693     DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
694 
695 #undef DISPATCH
696   }
697   return failure();
698 }
699 
700 //===----------------------------------------------------------------------===//
701 // CmpFOpPattern
702 //===----------------------------------------------------------------------===//
703 
704 LogicalResult
705 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
706                                ConversionPatternRewriter &rewriter) const {
707   switch (op.getPredicate()) {
708 #define DISPATCH(cmpPredicate, spirvOp)                                        \
709   case cmpPredicate:                                                           \
710     rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
711                                          adaptor.getLhs(), adaptor.getRhs());  \
712     return success();
713 
714     // Ordered.
715     DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
716     DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
717     DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
718     DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
719     DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
720     DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
721     // Unordered.
722     DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
723     DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
724     DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
725     DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
726     DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
727     DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
728 
729 #undef DISPATCH
730 
731   default:
732     break;
733   }
734   return failure();
735 }
736 
737 //===----------------------------------------------------------------------===//
738 // CmpFOpNanKernelPattern
739 //===----------------------------------------------------------------------===//
740 
741 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
742     arith::CmpFOp op, OpAdaptor adaptor,
743     ConversionPatternRewriter &rewriter) const {
744   if (op.getPredicate() == arith::CmpFPredicate::ORD) {
745     rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
746                                                   adaptor.getRhs());
747     return success();
748   }
749 
750   if (op.getPredicate() == arith::CmpFPredicate::UNO) {
751     rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
752                                                     adaptor.getRhs());
753     return success();
754   }
755 
756   return failure();
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // CmpFOpNanNonePattern
761 //===----------------------------------------------------------------------===//
762 
763 LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
764     arith::CmpFOp op, OpAdaptor adaptor,
765     ConversionPatternRewriter &rewriter) const {
766   if (op.getPredicate() != arith::CmpFPredicate::ORD &&
767       op.getPredicate() != arith::CmpFPredicate::UNO)
768     return failure();
769 
770   Location loc = op.getLoc();
771 
772   Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
773   Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
774 
775   Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
776   if (op.getPredicate() == arith::CmpFPredicate::ORD)
777     replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
778 
779   rewriter.replaceOp(op, replace);
780   return success();
781 }
782 
783 //===----------------------------------------------------------------------===//
784 // Pattern Population
785 //===----------------------------------------------------------------------===//
786 
787 void mlir::arith::populateArithmeticToSPIRVPatterns(
788     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
789   // clang-format off
790   patterns.add<
791     ConstantCompositeOpPattern,
792     ConstantScalarOpPattern,
793     spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
794     spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
795     spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
796     spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
797     spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
798     spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
799     RemSIOpGLSLPattern, RemSIOpOCLPattern,
800     BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
801     BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
802     XOrIOpLogicalPattern, XOrIOpBooleanPattern,
803     spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
804     spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
805     spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
806     spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
807     spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
808     spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
809     spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
810     spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
811     spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
812     TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
813     TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
814     TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
815     TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
816     TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
817     TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
818     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
819     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
820     TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
821     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
822     CmpIOpBooleanPattern, CmpIOpPattern,
823     CmpFOpNanNonePattern, CmpFOpPattern
824   >(typeConverter, patterns.getContext());
825   // clang-format on
826 
827   // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
828   // capability is available.
829   patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
830                                        /*benefit=*/2);
831 }
832 
833 //===----------------------------------------------------------------------===//
834 // Pass Definition
835 //===----------------------------------------------------------------------===//
836 
837 namespace {
838 struct ConvertArithmeticToSPIRVPass
839     : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
840   void runOnOperation() override {
841     auto module = getOperation()->getParentOfType<ModuleOp>();
842     auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
843     auto target = SPIRVConversionTarget::get(targetAttr);
844 
845     SPIRVTypeConverter::Options options;
846     options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
847     SPIRVTypeConverter typeConverter(targetAttr, options);
848 
849     RewritePatternSet patterns(&getContext());
850     mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
851 
852     if (failed(applyPartialConversion(getOperation(), *target,
853                                       std::move(patterns))))
854       signalPassFailure();
855   }
856 };
857 } // namespace
858 
859 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
860   return std::make_unique<ConvertArithmeticToSPIRVPass>();
861 }
862