1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
10 #include "../PassDetail.h"
11 #include "../SPIRVCommon/Pattern.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
16 #include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "arith-to-spirv-pattern"
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 // Operation Conversion
24 //===----------------------------------------------------------------------===//
25 
namespace {

/// Converts composite (vector or ranked tensor) arith.constant operations to
/// spv.Constant.
struct ConstantCompositeOpPattern final
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts scalar (integer, index or floating-point) arith.constant
/// operations to spv.Constant.
struct ConstantScalarOpPattern final
    : public OpConversionPattern<arith::ConstantOp> {
  using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.remsi to SPIR-V ops.
///
/// This cannot be merged into the template unary/binary pattern due to Vulkan
/// restrictions over spv.SRem and spv.SMod; the remainder is instead emulated
/// with spv.UMod on absolute values followed by a sign fix-up.
struct RemSIOpPattern final : public OpConversionPattern<arith::RemSIOp> {
  using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts bitwise operations to SPIR-V operations. This needs its own
/// pattern instead of the generic spirv::UnaryAndBinaryOpPattern because
/// SPIR-V picks a different operation depending on the operand type:
/// `SPIRVLogicalOp` for boolean (i1 or vector of i1) operands and
/// `SPIRVBitwiseOp` for all other operands.
template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
struct BitwiseOpPattern final : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.xori on non-boolean operands to spv.BitwiseXor.
/// (Despite its name, this pattern rejects i1 / vector-of-i1 operands; those
/// are handled by XOrIOpBooleanPattern.)
struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
  using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.xori to spv.LogicalNotEqual if the type of source is i1 or
/// vector of i1.
struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
  using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of
/// i1.
struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
  using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.extui to spv.Select if the type of source is i1 or vector of
/// i1.
struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
  using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts arith.trunci to spv.Select if the type of result is i1 or vector of
/// i1.
struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
  using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts type-casting arith operations to the SPIR-V op `SPIRVOp`. Casts
/// whose source or result is a boolean scalar/vector are rejected here and
/// left to the dedicated i1 patterns above.
template <typename Op, typename SPIRVOp>
struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
  using OpConversionPattern<Op>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts integer compare operation on i1 type operands to SPIR-V ops.
/// Only the eq and ne predicates are supported.
class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
public:
  using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts integer compare operation on non-boolean operands to SPIR-V ops.
class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
public:
  using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts floating-point comparison operations (ordered and unordered
/// predicates) to SPIR-V ops.
class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts floating point NaN checks (ORD/UNO predicates) to spv.Ordered /
/// spv.Unordered. This pattern requires Kernel capability.
class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

/// Converts floating point NaN checks (ORD/UNO predicates) to spv.IsNan-based
/// logic. This pattern does not require additional capability.
class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
public:
  using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override;
};

} // end anonymous namespace
185 
186 //===----------------------------------------------------------------------===//
187 // Conversion Helpers
188 //===----------------------------------------------------------------------===//
189 
190 /// Converts the given `srcAttr` into a boolean attribute if it holds an
191 /// integral value. Returns null attribute if conversion fails.
192 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
193   if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
194     return boolAttr;
195   if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
196     return builder.getBoolAttr(intAttr.getValue().getBoolValue());
197   return BoolAttr();
198 }
199 
200 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
201 /// Returns null attribute if conversion fails.
202 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
203                                       Builder builder) {
204   // If the source number uses less active bits than the target bitwidth, then
205   // it should be safe to convert.
206   if (srcAttr.getValue().isIntN(dstType.getWidth()))
207     return builder.getIntegerAttr(dstType, srcAttr.getInt());
208 
209   // XXX: Try again by interpreting the source number as a signed value.
210   // Although integers in the standard dialect are signless, they can represent
211   // a signed number. It's the operation decides how to interpret. This is
212   // dangerous, but it seems there is no good way of handling this if we still
213   // want to change the bitwidth. Emit a message at least.
214   if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
215     auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
216     LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
217                             << dstAttr << "' for type '" << dstType << "'\n");
218     return dstAttr;
219   }
220 
221   LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
222                           << "' illegal: cannot fit into target type '"
223                           << dstType << "'\n");
224   return IntegerAttr();
225 }
226 
227 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
228 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
229 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
230                                   Builder builder) {
231   // Only support converting to float for now.
232   if (!dstType.isF32())
233     return FloatAttr();
234 
235   // Try to convert the source floating-point number to single precision.
236   APFloat dstVal = srcAttr.getValue();
237   bool losesInfo = false;
238   APFloat::opStatus status =
239       dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
240   if (status != APFloat::opOK || losesInfo) {
241     LLVM_DEBUG(llvm::dbgs()
242                << srcAttr << " illegal: cannot fit into converted type '"
243                << dstType << "'\n");
244     return FloatAttr();
245   }
246 
247   return builder.getF32FloatAttr(dstVal.convertToFloat());
248 }
249 
250 /// Returns true if the given `type` is a boolean scalar or vector type.
251 static bool isBoolScalarOrVector(Type type) {
252   if (type.isInteger(1))
253     return true;
254   if (auto vecType = type.dyn_cast<VectorType>())
255     return vecType.getElementType().isInteger(1);
256   return false;
257 }
258 
259 //===----------------------------------------------------------------------===//
260 // ConstantOp with composite type
261 //===----------------------------------------------------------------------===//
262 
263 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
264     arith::ConstantOp constOp, OpAdaptor adaptor,
265     ConversionPatternRewriter &rewriter) const {
266   auto srcType = constOp.getType().dyn_cast<ShapedType>();
267   if (!srcType)
268     return failure();
269 
270   // arith.constant should only have vector or tenor types.
271   assert((srcType.isa<VectorType, RankedTensorType>()));
272 
273   auto dstType = getTypeConverter()->convertType(srcType);
274   if (!dstType)
275     return failure();
276 
277   auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>();
278   ShapedType dstAttrType = dstElementsAttr.getType();
279   if (!dstElementsAttr)
280     return failure();
281 
282   // If the composite type has more than one dimensions, perform linearization.
283   if (srcType.getRank() > 1) {
284     if (srcType.isa<RankedTensorType>()) {
285       dstAttrType = RankedTensorType::get(srcType.getNumElements(),
286                                           srcType.getElementType());
287       dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
288     } else {
289       // TODO: add support for large vectors.
290       return failure();
291     }
292   }
293 
294   Type srcElemType = srcType.getElementType();
295   Type dstElemType;
296   // Tensor types are converted to SPIR-V array types; vector types are
297   // converted to SPIR-V vector/array types.
298   if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
299     dstElemType = arrayType.getElementType();
300   else
301     dstElemType = dstType.cast<VectorType>().getElementType();
302 
303   // If the source and destination element types are different, perform
304   // attribute conversion.
305   if (srcElemType != dstElemType) {
306     SmallVector<Attribute, 8> elements;
307     if (srcElemType.isa<FloatType>()) {
308       for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
309         FloatAttr dstAttr =
310             convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
311         if (!dstAttr)
312           return failure();
313         elements.push_back(dstAttr);
314       }
315     } else if (srcElemType.isInteger(1)) {
316       return failure();
317     } else {
318       for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
319         IntegerAttr dstAttr = convertIntegerAttr(
320             srcAttr, dstElemType.cast<IntegerType>(), rewriter);
321         if (!dstAttr)
322           return failure();
323         elements.push_back(dstAttr);
324       }
325     }
326 
327     // Unfortunately, we cannot use dialect-specific types for element
328     // attributes; element attributes only works with builtin types. So we need
329     // to prepare another converted builtin types for the destination elements
330     // attribute.
331     if (dstAttrType.isa<RankedTensorType>())
332       dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
333     else
334       dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
335 
336     dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
337   }
338 
339   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
340                                                  dstElementsAttr);
341   return success();
342 }
343 
344 //===----------------------------------------------------------------------===//
345 // ConstantOp with scalar type
346 //===----------------------------------------------------------------------===//
347 
348 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
349     arith::ConstantOp constOp, OpAdaptor adaptor,
350     ConversionPatternRewriter &rewriter) const {
351   Type srcType = constOp.getType();
352   if (!srcType.isIntOrIndexOrFloat())
353     return failure();
354 
355   Type dstType = getTypeConverter()->convertType(srcType);
356   if (!dstType)
357     return failure();
358 
359   // Floating-point types.
360   if (srcType.isa<FloatType>()) {
361     auto srcAttr = constOp.getValue().cast<FloatAttr>();
362     auto dstAttr = srcAttr;
363 
364     // Floating-point types not supported in the target environment are all
365     // converted to float type.
366     if (srcType != dstType) {
367       dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
368       if (!dstAttr)
369         return failure();
370     }
371 
372     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
373     return success();
374   }
375 
376   // Bool type.
377   if (srcType.isInteger(1)) {
378     // arith.constant can use 0/1 instead of true/false for i1 values. We need
379     // to handle that here.
380     auto dstAttr = convertBoolAttr(constOp.getValue(), rewriter);
381     if (!dstAttr)
382       return failure();
383     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
384     return success();
385   }
386 
387   // IndexType or IntegerType. Index values are converted to 32-bit integer
388   // values when converting to SPIR-V.
389   auto srcAttr = constOp.getValue().cast<IntegerAttr>();
390   auto dstAttr =
391       convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
392   if (!dstAttr)
393     return failure();
394   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
395   return success();
396 }
397 
398 //===----------------------------------------------------------------------===//
399 // RemSIOpPattern
400 //===----------------------------------------------------------------------===//
401 
402 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
403 /// the sign of `signOperand`.
404 ///
405 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
406 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
407 /// the result is undefined."  So we cannot directly use spv.SRem/spv.SMod
408 /// if either operand can be negative. Emulate it via spv.UMod.
409 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
410                                     Value signOperand, OpBuilder &builder) {
411   assert(lhs.getType() == rhs.getType());
412   assert(lhs == signOperand || rhs == signOperand);
413 
414   Type type = lhs.getType();
415 
416   // Calculate the remainder with spv.UMod.
417   Value lhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, lhs);
418   Value rhsAbs = builder.create<spirv::GLSLSAbsOp>(loc, type, rhs);
419   Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
420 
421   // Fix the sign.
422   Value isPositive;
423   if (lhs == signOperand)
424     isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
425   else
426     isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
427   Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
428   return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
429 }
430 
431 LogicalResult
432 RemSIOpPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
433                                 ConversionPatternRewriter &rewriter) const {
434   Value result = emulateSignedRemainder(op.getLoc(), adaptor.getOperands()[0],
435                                         adaptor.getOperands()[1],
436                                         adaptor.getOperands()[0], rewriter);
437   rewriter.replaceOp(op, result);
438 
439   return success();
440 }
441 
442 //===----------------------------------------------------------------------===//
443 // BitwiseOpPattern
444 //===----------------------------------------------------------------------===//
445 
446 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
447 LogicalResult
448 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
449     Op op, typename Op::Adaptor adaptor,
450     ConversionPatternRewriter &rewriter) const {
451   assert(adaptor.getOperands().size() == 2);
452   auto dstType =
453       this->getTypeConverter()->convertType(op.getResult().getType());
454   if (!dstType)
455     return failure();
456   if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
457     rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
458                                                          adaptor.getOperands());
459   } else {
460     rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
461                                                          adaptor.getOperands());
462   }
463   return success();
464 }
465 
466 //===----------------------------------------------------------------------===//
467 // XOrIOpLogicalPattern
468 //===----------------------------------------------------------------------===//
469 
470 LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
471     arith::XOrIOp op, OpAdaptor adaptor,
472     ConversionPatternRewriter &rewriter) const {
473   assert(adaptor.getOperands().size() == 2);
474 
475   if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
476     return failure();
477 
478   auto dstType = getTypeConverter()->convertType(op.getType());
479   if (!dstType)
480     return failure();
481   rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
482                                                    adaptor.getOperands());
483 
484   return success();
485 }
486 
487 //===----------------------------------------------------------------------===//
488 // XOrIOpBooleanPattern
489 //===----------------------------------------------------------------------===//
490 
491 LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
492     arith::XOrIOp op, OpAdaptor adaptor,
493     ConversionPatternRewriter &rewriter) const {
494   assert(adaptor.getOperands().size() == 2);
495 
496   if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
497     return failure();
498 
499   auto dstType = getTypeConverter()->convertType(op.getType());
500   if (!dstType)
501     return failure();
502   rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
503                                                         adaptor.getOperands());
504   return success();
505 }
506 
507 //===----------------------------------------------------------------------===//
508 // UIToFPI1Pattern
509 //===----------------------------------------------------------------------===//
510 
511 LogicalResult
512 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
513                                  ConversionPatternRewriter &rewriter) const {
514   auto srcType = adaptor.getOperands().front().getType();
515   if (!isBoolScalarOrVector(srcType))
516     return failure();
517 
518   auto dstType =
519       this->getTypeConverter()->convertType(op.getResult().getType());
520   Location loc = op.getLoc();
521   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
522   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
523   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
524       op, dstType, adaptor.getOperands().front(), one, zero);
525   return success();
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // ExtUII1Pattern
530 //===----------------------------------------------------------------------===//
531 
532 LogicalResult
533 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
534                                 ConversionPatternRewriter &rewriter) const {
535   auto srcType = adaptor.getOperands().front().getType();
536   if (!isBoolScalarOrVector(srcType))
537     return failure();
538 
539   auto dstType =
540       this->getTypeConverter()->convertType(op.getResult().getType());
541   Location loc = op.getLoc();
542   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
543   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
544   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
545       op, dstType, adaptor.getOperands().front(), one, zero);
546   return success();
547 }
548 
549 //===----------------------------------------------------------------------===//
550 // TruncII1Pattern
551 //===----------------------------------------------------------------------===//
552 
553 LogicalResult
554 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
555                                  ConversionPatternRewriter &rewriter) const {
556   auto dstType =
557       this->getTypeConverter()->convertType(op.getResult().getType());
558   if (!isBoolScalarOrVector(dstType))
559     return failure();
560 
561   Location loc = op.getLoc();
562   auto srcType = adaptor.getOperands().front().getType();
563   // Check if (x & 1) == 1.
564   Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
565   Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
566       loc, srcType, adaptor.getOperands()[0], mask);
567   Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
568 
569   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
570   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
571   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
572   return success();
573 }
574 
575 //===----------------------------------------------------------------------===//
576 // TypeCastingOpPattern
577 //===----------------------------------------------------------------------===//
578 
579 template <typename Op, typename SPIRVOp>
580 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
581     Op op, typename Op::Adaptor adaptor,
582     ConversionPatternRewriter &rewriter) const {
583   assert(adaptor.getOperands().size() == 1);
584   auto srcType = adaptor.getOperands().front().getType();
585   auto dstType =
586       this->getTypeConverter()->convertType(op.getResult().getType());
587   if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
588     return failure();
589   if (dstType == srcType) {
590     // Due to type conversion, we are seeing the same source and target type.
591     // Then we can just erase this operation by forwarding its operand.
592     rewriter.replaceOp(op, adaptor.getOperands().front());
593   } else {
594     rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
595                                                   adaptor.getOperands());
596   }
597   return success();
598 }
599 
600 //===----------------------------------------------------------------------===//
601 // CmpIOpBooleanPattern
602 //===----------------------------------------------------------------------===//
603 
604 LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
605     arith::CmpIOp op, OpAdaptor adaptor,
606     ConversionPatternRewriter &rewriter) const {
607   Type operandType = op.getLhs().getType();
608   if (!isBoolScalarOrVector(operandType))
609     return failure();
610 
611   switch (op.getPredicate()) {
612 #define DISPATCH(cmpPredicate, spirvOp)                                        \
613   case cmpPredicate:                                                           \
614     rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
615                                          adaptor.lhs(), adaptor.rhs());        \
616     return success();
617 
618     DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp);
619     DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp);
620 
621 #undef DISPATCH
622   default:;
623   }
624   return failure();
625 }
626 
627 //===----------------------------------------------------------------------===//
628 // CmpIOpPattern
629 //===----------------------------------------------------------------------===//
630 
/// Lowers arith.cmpi on non-boolean operands to the matching SPIR-V integer
/// comparison; boolean operands are rejected (see CmpIOpBooleanPattern). The
/// result type is taken unchanged from the original op.
///
/// For unsigned predicates, an error is emitted if the type converter changes
/// the operand type, because bitwidth emulation is not implemented for
/// unsigned comparisons.
LogicalResult
CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  Type operandType = op.getLhs().getType();
  if (isBoolScalarOrVector(operandType))
    return failure();

  switch (op.getPredicate()) {
  // DISPATCH maps one predicate to one SPIR-V op. The convertType() query only
  // runs for ops carrying the UnsignedOp trait (short-circuited `&&`).
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
        operandType != this->getTypeConverter()->convertType(operandType)) {   \
      return op.emitError(                                                     \
          "bitwidth emulation is not implemented yet on unsigned op");         \
    }                                                                          \
    rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
                                         adaptor.lhs(), adaptor.rhs());        \
    return success();

    DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
    DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
    DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
    DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
    DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
    DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
    DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
    DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
    DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
    DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);

#undef DISPATCH
  }
  // Only reached for a predicate value not listed above.
  return failure();
}
665 
666 //===----------------------------------------------------------------------===//
667 // CmpFOpPattern
668 //===----------------------------------------------------------------------===//
669 
670 LogicalResult
671 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
672                                ConversionPatternRewriter &rewriter) const {
673   switch (op.getPredicate()) {
674 #define DISPATCH(cmpPredicate, spirvOp)                                        \
675   case cmpPredicate:                                                           \
676     rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
677                                          adaptor.lhs(), adaptor.rhs());        \
678     return success();
679 
680     // Ordered.
681     DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
682     DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
683     DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
684     DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
685     DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
686     DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
687     // Unordered.
688     DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
689     DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
690     DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
691     DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
692     DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
693     DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
694 
695 #undef DISPATCH
696 
697   default:
698     break;
699   }
700   return failure();
701 }
702 
703 //===----------------------------------------------------------------------===//
704 // CmpFOpNanKernelPattern
705 //===----------------------------------------------------------------------===//
706 
707 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
708     arith::CmpFOp op, OpAdaptor adaptor,
709     ConversionPatternRewriter &rewriter) const {
710   if (op.getPredicate() == arith::CmpFPredicate::ORD) {
711     rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
712                                                   adaptor.getRhs());
713     return success();
714   }
715 
716   if (op.getPredicate() == arith::CmpFPredicate::UNO) {
717     rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
718                                                     adaptor.getRhs());
719     return success();
720   }
721 
722   return failure();
723 }
724 
725 //===----------------------------------------------------------------------===//
726 // CmpFOpNanNonePattern
727 //===----------------------------------------------------------------------===//
728 
729 LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
730     arith::CmpFOp op, OpAdaptor adaptor,
731     ConversionPatternRewriter &rewriter) const {
732   if (op.getPredicate() != arith::CmpFPredicate::ORD &&
733       op.getPredicate() != arith::CmpFPredicate::UNO)
734     return failure();
735 
736   Location loc = op.getLoc();
737 
738   Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
739   Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
740 
741   Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
742   if (op.getPredicate() == arith::CmpFPredicate::ORD)
743     replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
744 
745   rewriter.replaceOp(op, replace);
746   return success();
747 }
748 
749 //===----------------------------------------------------------------------===//
750 // Pattern Population
751 //===----------------------------------------------------------------------===//
752 
753 void mlir::arith::populateArithmeticToSPIRVPatterns(
754     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
755   // clang-format off
756   patterns.add<
757     ConstantCompositeOpPattern,
758     ConstantScalarOpPattern,
759     spirv::UnaryAndBinaryOpPattern<arith::AddIOp, spirv::IAddOp>,
760     spirv::UnaryAndBinaryOpPattern<arith::SubIOp, spirv::ISubOp>,
761     spirv::UnaryAndBinaryOpPattern<arith::MulIOp, spirv::IMulOp>,
762     spirv::UnaryAndBinaryOpPattern<arith::DivUIOp, spirv::UDivOp>,
763     spirv::UnaryAndBinaryOpPattern<arith::DivSIOp, spirv::SDivOp>,
764     spirv::UnaryAndBinaryOpPattern<arith::RemUIOp, spirv::UModOp>,
765     RemSIOpPattern,
766     BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
767     BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
768     XOrIOpLogicalPattern, XOrIOpBooleanPattern,
769     spirv::UnaryAndBinaryOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
770     spirv::UnaryAndBinaryOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
771     spirv::UnaryAndBinaryOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
772     spirv::UnaryAndBinaryOpPattern<arith::NegFOp, spirv::FNegateOp>,
773     spirv::UnaryAndBinaryOpPattern<arith::AddFOp, spirv::FAddOp>,
774     spirv::UnaryAndBinaryOpPattern<arith::SubFOp, spirv::FSubOp>,
775     spirv::UnaryAndBinaryOpPattern<arith::MulFOp, spirv::FMulOp>,
776     spirv::UnaryAndBinaryOpPattern<arith::DivFOp, spirv::FDivOp>,
777     spirv::UnaryAndBinaryOpPattern<arith::RemFOp, spirv::FRemOp>,
778     TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
779     TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
780     TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
781     TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
782     TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
783     TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
784     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
785     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
786     TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
787     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
788     CmpIOpBooleanPattern, CmpIOpPattern,
789     CmpFOpNanNonePattern, CmpFOpPattern
790   >(typeConverter, patterns.getContext());
791   // clang-format on
792 
793   // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
794   // capability is available.
795   patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
796                                        /*benefit=*/2);
797 }
798 
799 //===----------------------------------------------------------------------===//
800 // Pass Definition
801 //===----------------------------------------------------------------------===//
802 
803 namespace {
804 struct ConvertArithmeticToSPIRVPass
805     : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
806   void runOnFunction() override {
807     auto module = getOperation()->getParentOfType<ModuleOp>();
808     auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
809     auto target = SPIRVConversionTarget::get(targetAttr);
810 
811     SPIRVTypeConverter::Options options;
812     options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
813     SPIRVTypeConverter typeConverter(targetAttr, options);
814 
815     RewritePatternSet patterns(&getContext());
816     mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
817 
818     if (failed(applyPartialConversion(getOperation(), *target,
819                                       std::move(patterns))))
820       signalPassFailure();
821   }
822 };
823 } // end anonymous namespace
824 
825 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
826   return std::make_unique<ConvertArithmeticToSPIRVPass>();
827 }
828