1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
10 #include "../PassDetail.h"
11 #include "../SPIRVCommon/Pattern.h"
12 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
13 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
15 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
16 #include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "arith-to-spirv-pattern"
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 // Operation Conversion
24 //===----------------------------------------------------------------------===//
25 
26 namespace {
27 
28 /// Converts composite arith.constant operation to spv.Constant.
29 struct ConstantCompositeOpPattern final
30     : public OpConversionPattern<arith::ConstantOp> {
31   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
32 
33   LogicalResult
34   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
35                   ConversionPatternRewriter &rewriter) const override;
36 };
37 
38 /// Converts scalar arith.constant operation to spv.Constant.
39 struct ConstantScalarOpPattern final
40     : public OpConversionPattern<arith::ConstantOp> {
41   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
42 
43   LogicalResult
44   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
45                   ConversionPatternRewriter &rewriter) const override;
46 };
47 
48 /// Converts arith.remsi to GLSL SPIR-V ops.
49 ///
50 /// This cannot be merged into the template unary/binary pattern due to Vulkan
51 /// restrictions over spv.SRem and spv.SMod.
52 struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> {
53   using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
54 
55   LogicalResult
56   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
57                   ConversionPatternRewriter &rewriter) const override;
58 };
59 
60 /// Converts arith.remsi to OpenCL SPIR-V ops.
61 struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> {
62   using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
63 
64   LogicalResult
65   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
66                   ConversionPatternRewriter &rewriter) const override;
67 };
68 
69 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
70 /// other than the BinaryOpPatternPattern because if the operands are boolean
71 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
72 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
73 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
74 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
75   using OpConversionPattern<Op>::OpConversionPattern;
76 
77   LogicalResult
78   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
79                   ConversionPatternRewriter &rewriter) const override;
80 };
81 
82 /// Converts arith.xori to SPIR-V operations.
83 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
84   using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
85 
86   LogicalResult
87   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
88                   ConversionPatternRewriter &rewriter) const override;
89 };
90 
91 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
92 /// vector of i1.
93 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
94   using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
95 
96   LogicalResult
97   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
98                   ConversionPatternRewriter &rewriter) const override;
99 };
100 
101 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of
102 /// i1.
103 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
104   using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern;
105 
106   LogicalResult
107   matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
108                   ConversionPatternRewriter &rewriter) const override;
109 };
110 
111 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of
112 /// i1.
113 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
114   using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;
115 
116   LogicalResult
117   matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
118                   ConversionPatternRewriter &rewriter) const override;
119 };
120 
121 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of
122 /// i1.
123 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
124   using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;
125 
126   LogicalResult
127   matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
128                   ConversionPatternRewriter &rewriter) const override;
129 };
130 
131 /// Converts type-casting standard operations to SPIR-V operations.
132 template <typename Op, typename SPIRVOp>
133 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
134   using OpConversionPattern<Op>::OpConversionPattern;
135 
136   LogicalResult
137   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
138                   ConversionPatternRewriter &rewriter) const override;
139 };
140 
141 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
142 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
143 public:
144   using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
145 
146   LogicalResult
147   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
148                   ConversionPatternRewriter &rewriter) const override;
149 };
150 
151 /// Converts integer compare operation to SPIR-V ops.
152 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
153 public:
154   using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
155 
156   LogicalResult
157   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
158                   ConversionPatternRewriter &rewriter) const override;
159 };
160 
161 /// Converts floating-point comparison operations to SPIR-V ops.
162 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
163 public:
164   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
165 
166   LogicalResult
167   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
168                   ConversionPatternRewriter &rewriter) const override;
169 };
170 
171 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
172 /// Kernel capability.
173 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
174 public:
175   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
176 
177   LogicalResult
178   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
179                   ConversionPatternRewriter &rewriter) const override;
180 };
181 
182 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
183 /// require additional capability.
184 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
185 public:
186   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
187 
188   LogicalResult
189   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
190                   ConversionPatternRewriter &rewriter) const override;
191 };
192 
193 /// Converts arith.select to spv.Select.
194 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
195 public:
196   using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
197   LogicalResult
198   matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
199                   ConversionPatternRewriter &rewriter) const override;
200 };
201 
202 } // namespace
203 
204 //===----------------------------------------------------------------------===//
205 // Conversion Helpers
206 //===----------------------------------------------------------------------===//
207 
208 /// Converts the given `srcAttr` into a boolean attribute if it holds an
209 /// integral value. Returns null attribute if conversion fails.
210 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
211   if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
212     return boolAttr;
213   if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
214     return builder.getBoolAttr(intAttr.getValue().getBoolValue());
215   return BoolAttr();
216 }
217 
218 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
219 /// Returns null attribute if conversion fails.
220 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
221                                       Builder builder) {
222   // If the source number uses less active bits than the target bitwidth, then
223   // it should be safe to convert.
224   if (srcAttr.getValue().isIntN(dstType.getWidth()))
225     return builder.getIntegerAttr(dstType, srcAttr.getInt());
226 
227   // XXX: Try again by interpreting the source number as a signed value.
228   // Although integers in the standard dialect are signless, they can represent
229   // a signed number. It's the operation decides how to interpret. This is
230   // dangerous, but it seems there is no good way of handling this if we still
231   // want to change the bitwidth. Emit a message at least.
232   if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
233     auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
234     LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
235                             << dstAttr << "' for type '" << dstType << "'\n");
236     return dstAttr;
237   }
238 
239   LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
240                           << "' illegal: cannot fit into target type '"
241                           << dstType << "'\n");
242   return IntegerAttr();
243 }
244 
245 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
246 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
247 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
248                                   Builder builder) {
249   // Only support converting to float for now.
250   if (!dstType.isF32())
251     return FloatAttr();
252 
253   // Try to convert the source floating-point number to single precision.
254   APFloat dstVal = srcAttr.getValue();
255   bool losesInfo = false;
256   APFloat::opStatus status =
257       dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
258   if (status != APFloat::opOK || losesInfo) {
259     LLVM_DEBUG(llvm::dbgs()
260                << srcAttr << " illegal: cannot fit into converted type '"
261                << dstType << "'\n");
262     return FloatAttr();
263   }
264 
265   return builder.getF32FloatAttr(dstVal.convertToFloat());
266 }
267 
268 /// Returns true if the given `type` is a boolean scalar or vector type.
269 static bool isBoolScalarOrVector(Type type) {
270   if (type.isInteger(1))
271     return true;
272   if (auto vecType = type.dyn_cast<VectorType>())
273     return vecType.getElementType().isInteger(1);
274   return false;
275 }
276 
277 //===----------------------------------------------------------------------===//
278 // ConstantOp with composite type
279 //===----------------------------------------------------------------------===//
280 
281 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
282     arith::ConstantOp constOp, OpAdaptor adaptor,
283     ConversionPatternRewriter &rewriter) const {
284   auto srcType = constOp.getType().dyn_cast<ShapedType>();
285   if (!srcType || srcType.getNumElements() == 1)
286     return failure();
287 
288   // arith.constant should only have vector or tenor types.
289   assert((srcType.isa<VectorType, RankedTensorType>()));
290 
291   auto dstType = getTypeConverter()->convertType(srcType);
292   if (!dstType)
293     return failure();
294 
295   auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>();
296   ShapedType dstAttrType = dstElementsAttr.getType();
297   if (!dstElementsAttr)
298     return failure();
299 
300   // If the composite type has more than one dimensions, perform linearization.
301   if (srcType.getRank() > 1) {
302     if (srcType.isa<RankedTensorType>()) {
303       dstAttrType = RankedTensorType::get(srcType.getNumElements(),
304                                           srcType.getElementType());
305       dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
306     } else {
307       // TODO: add support for large vectors.
308       return failure();
309     }
310   }
311 
312   Type srcElemType = srcType.getElementType();
313   Type dstElemType;
314   // Tensor types are converted to SPIR-V array types; vector types are
315   // converted to SPIR-V vector/array types.
316   if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
317     dstElemType = arrayType.getElementType();
318   else
319     dstElemType = dstType.cast<VectorType>().getElementType();
320 
321   // If the source and destination element types are different, perform
322   // attribute conversion.
323   if (srcElemType != dstElemType) {
324     SmallVector<Attribute, 8> elements;
325     if (srcElemType.isa<FloatType>()) {
326       for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
327         FloatAttr dstAttr =
328             convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
329         if (!dstAttr)
330           return failure();
331         elements.push_back(dstAttr);
332       }
333     } else if (srcElemType.isInteger(1)) {
334       return failure();
335     } else {
336       for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
337         IntegerAttr dstAttr = convertIntegerAttr(
338             srcAttr, dstElemType.cast<IntegerType>(), rewriter);
339         if (!dstAttr)
340           return failure();
341         elements.push_back(dstAttr);
342       }
343     }
344 
345     // Unfortunately, we cannot use dialect-specific types for element
346     // attributes; element attributes only works with builtin types. So we need
347     // to prepare another converted builtin types for the destination elements
348     // attribute.
349     if (dstAttrType.isa<RankedTensorType>())
350       dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
351     else
352       dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
353 
354     dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
355   }
356 
357   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
358                                                  dstElementsAttr);
359   return success();
360 }
361 
362 //===----------------------------------------------------------------------===//
363 // ConstantOp with scalar type
364 //===----------------------------------------------------------------------===//
365 
366 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
367     arith::ConstantOp constOp, OpAdaptor adaptor,
368     ConversionPatternRewriter &rewriter) const {
369   Type srcType = constOp.getType();
370   if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
371     if (shapedType.getNumElements() != 1)
372       return failure();
373     srcType = shapedType.getElementType();
374   }
375   if (!srcType.isIntOrIndexOrFloat())
376     return failure();
377 
378   Attribute cstAttr = constOp.getValue();
379   if (cstAttr.getType().isa<ShapedType>())
380     cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>();
381 
382   Type dstType = getTypeConverter()->convertType(srcType);
383   if (!dstType)
384     return failure();
385 
386   // Floating-point types.
387   if (srcType.isa<FloatType>()) {
388     auto srcAttr = cstAttr.cast<FloatAttr>();
389     auto dstAttr = srcAttr;
390 
391     // Floating-point types not supported in the target environment are all
392     // converted to float type.
393     if (srcType != dstType) {
394       dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
395       if (!dstAttr)
396         return failure();
397     }
398 
399     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
400     return success();
401   }
402 
403   // Bool type.
404   if (srcType.isInteger(1)) {
405     // arith.constant can use 0/1 instead of true/false for i1 values. We need
406     // to handle that here.
407     auto dstAttr = convertBoolAttr(cstAttr, rewriter);
408     if (!dstAttr)
409       return failure();
410     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
411     return success();
412   }
413 
414   // IndexType or IntegerType. Index values are converted to 32-bit integer
415   // values when converting to SPIR-V.
416   auto srcAttr = cstAttr.cast<IntegerAttr>();
417   auto dstAttr =
418       convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
419   if (!dstAttr)
420     return failure();
421   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
422   return success();
423 }
424 
425 //===----------------------------------------------------------------------===//
426 // RemSIOpGLSLPattern
427 //===----------------------------------------------------------------------===//
428 
429 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
430 /// the sign of `signOperand`.
431 ///
432 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
433 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
434 /// the result is undefined."  So we cannot directly use spv.SRem/spv.SMod
435 /// if either operand can be negative. Emulate it via spv.UMod.
436 template <typename SignedAbsOp>
437 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
438                                     Value signOperand, OpBuilder &builder) {
439   assert(lhs.getType() == rhs.getType());
440   assert(lhs == signOperand || rhs == signOperand);
441 
442   Type type = lhs.getType();
443 
444   // Calculate the remainder with spv.UMod.
445   Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
446   Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
447   Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
448 
449   // Fix the sign.
450   Value isPositive;
451   if (lhs == signOperand)
452     isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
453   else
454     isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
455   Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
456   return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
457 }
458 
459 LogicalResult
460 RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
461                                     ConversionPatternRewriter &rewriter) const {
462   Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>(
463       op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
464       adaptor.getOperands()[0], rewriter);
465   rewriter.replaceOp(op, result);
466 
467   return success();
468 }
469 
470 //===----------------------------------------------------------------------===//
471 // RemSIOpOCLPattern
472 //===----------------------------------------------------------------------===//
473 
474 LogicalResult
475 RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
476                                    ConversionPatternRewriter &rewriter) const {
477   Value result = emulateSignedRemainder<spirv::OCLSAbsOp>(
478       op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
479       adaptor.getOperands()[0], rewriter);
480   rewriter.replaceOp(op, result);
481 
482   return success();
483 }
484 
485 //===----------------------------------------------------------------------===//
486 // BitwiseOpPattern
487 //===----------------------------------------------------------------------===//
488 
489 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
490 LogicalResult
491 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
492     Op op, typename Op::Adaptor adaptor,
493     ConversionPatternRewriter &rewriter) const {
494   assert(adaptor.getOperands().size() == 2);
495   auto dstType =
496       this->getTypeConverter()->convertType(op.getResult().getType());
497   if (!dstType)
498     return failure();
499   if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
500     rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
501                                                          adaptor.getOperands());
502   } else {
503     rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
504                                                          adaptor.getOperands());
505   }
506   return success();
507 }
508 
509 //===----------------------------------------------------------------------===//
510 // XOrIOpLogicalPattern
511 //===----------------------------------------------------------------------===//
512 
513 LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
514     arith::XOrIOp op, OpAdaptor adaptor,
515     ConversionPatternRewriter &rewriter) const {
516   assert(adaptor.getOperands().size() == 2);
517 
518   if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
519     return failure();
520 
521   auto dstType = getTypeConverter()->convertType(op.getType());
522   if (!dstType)
523     return failure();
524   rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
525                                                    adaptor.getOperands());
526 
527   return success();
528 }
529 
530 //===----------------------------------------------------------------------===//
531 // XOrIOpBooleanPattern
532 //===----------------------------------------------------------------------===//
533 
534 LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
535     arith::XOrIOp op, OpAdaptor adaptor,
536     ConversionPatternRewriter &rewriter) const {
537   assert(adaptor.getOperands().size() == 2);
538 
539   if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
540     return failure();
541 
542   auto dstType = getTypeConverter()->convertType(op.getType());
543   if (!dstType)
544     return failure();
545   rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
546                                                         adaptor.getOperands());
547   return success();
548 }
549 
550 //===----------------------------------------------------------------------===//
551 // UIToFPI1Pattern
552 //===----------------------------------------------------------------------===//
553 
554 LogicalResult
555 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
556                                  ConversionPatternRewriter &rewriter) const {
557   auto srcType = adaptor.getOperands().front().getType();
558   if (!isBoolScalarOrVector(srcType))
559     return failure();
560 
561   auto dstType =
562       this->getTypeConverter()->convertType(op.getResult().getType());
563   Location loc = op.getLoc();
564   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
565   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
566   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
567       op, dstType, adaptor.getOperands().front(), one, zero);
568   return success();
569 }
570 
571 //===----------------------------------------------------------------------===//
572 // ExtUII1Pattern
573 //===----------------------------------------------------------------------===//
574 
575 LogicalResult
576 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
577                                 ConversionPatternRewriter &rewriter) const {
578   auto srcType = adaptor.getOperands().front().getType();
579   if (!isBoolScalarOrVector(srcType))
580     return failure();
581 
582   auto dstType =
583       this->getTypeConverter()->convertType(op.getResult().getType());
584   Location loc = op.getLoc();
585   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
586   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
587   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
588       op, dstType, adaptor.getOperands().front(), one, zero);
589   return success();
590 }
591 
592 //===----------------------------------------------------------------------===//
593 // TruncII1Pattern
594 //===----------------------------------------------------------------------===//
595 
596 LogicalResult
597 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
598                                  ConversionPatternRewriter &rewriter) const {
599   auto dstType =
600       this->getTypeConverter()->convertType(op.getResult().getType());
601   if (!isBoolScalarOrVector(dstType))
602     return failure();
603 
604   Location loc = op.getLoc();
605   auto srcType = adaptor.getOperands().front().getType();
606   // Check if (x & 1) == 1.
607   Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
608   Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
609       loc, srcType, adaptor.getOperands()[0], mask);
610   Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
611 
612   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
613   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
614   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
615   return success();
616 }
617 
618 //===----------------------------------------------------------------------===//
619 // TypeCastingOpPattern
620 //===----------------------------------------------------------------------===//
621 
622 template <typename Op, typename SPIRVOp>
623 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
624     Op op, typename Op::Adaptor adaptor,
625     ConversionPatternRewriter &rewriter) const {
626   assert(adaptor.getOperands().size() == 1);
627   auto srcType = adaptor.getOperands().front().getType();
628   auto dstType =
629       this->getTypeConverter()->convertType(op.getResult().getType());
630   if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
631     return failure();
632   if (dstType == srcType) {
633     // Due to type conversion, we are seeing the same source and target type.
634     // Then we can just erase this operation by forwarding its operand.
635     rewriter.replaceOp(op, adaptor.getOperands().front());
636   } else {
637     rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
638                                                   adaptor.getOperands());
639   }
640   return success();
641 }
642 
643 //===----------------------------------------------------------------------===//
644 // CmpIOpBooleanPattern
645 //===----------------------------------------------------------------------===//
646 
647 LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
648     arith::CmpIOp op, OpAdaptor adaptor,
649     ConversionPatternRewriter &rewriter) const {
650   Type operandType = op.getLhs().getType();
651   if (!isBoolScalarOrVector(operandType))
652     return failure();
653 
654   switch (op.getPredicate()) {
655 #define DISPATCH(cmpPredicate, spirvOp)                                        \
656   case cmpPredicate:                                                           \
657     rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
658                                          adaptor.getLhs(), adaptor.getRhs());  \
659     return success();
660 
661     DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp);
662     DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp);
663 
664 #undef DISPATCH
665   default:;
666   }
667   return failure();
668 }
669 
670 //===----------------------------------------------------------------------===//
671 // CmpIOpPattern
672 //===----------------------------------------------------------------------===//
673 
/// Lowers `arith.cmpi` on non-boolean operands to the SPIR-V integer
/// comparison that matches its predicate. Boolean (i1 / vector-of-i1) operands
/// are rejected here.
///
/// For predicates mapped to SPIR-V ops carrying the UnsignedOp trait, the
/// pattern emits an error (and fails) if the type converter changes the
/// operand type, since that bitwidth-emulation case is not implemented.
LogicalResult
CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  Type operandType = op.getLhs().getType();
  if (isBoolScalarOrVector(operandType))
    return failure();

  switch (op.getPredicate()) {
  // DISPATCH(pred, Op) expands to a `case pred:` that replaces the cmpi with
  // `Op` on the converted operands, keeping the original result type. The
  // unsigned check only queries the type converter for unsigned ops (the `&&`
  // short-circuits for signed/equality ops).
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
        operandType != this->getTypeConverter()->convertType(operandType)) {   \
      return op.emitError(                                                     \
          "bitwidth emulation is not implemented yet on unsigned op");         \
    }                                                                          \
    rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
                                         adaptor.getLhs(), adaptor.getRhs());  \
    return success();

    DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
    DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
    DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
    DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
    DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
    DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
    DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
    DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
    DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
    DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);

#undef DISPATCH
  }
  // No `default:` above; presumably every CmpIPredicate is dispatched, so
  // this is only a fallback. NOTE(review): confirm against the enum.
  return failure();
}
708 
709 //===----------------------------------------------------------------------===//
710 // CmpFOpPattern
711 //===----------------------------------------------------------------------===//
712 
713 LogicalResult
714 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
715                                ConversionPatternRewriter &rewriter) const {
716   switch (op.getPredicate()) {
717 #define DISPATCH(cmpPredicate, spirvOp)                                        \
718   case cmpPredicate:                                                           \
719     rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
720                                          adaptor.getLhs(), adaptor.getRhs());  \
721     return success();
722 
723     // Ordered.
724     DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
725     DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
726     DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
727     DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
728     DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
729     DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
730     // Unordered.
731     DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
732     DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
733     DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
734     DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
735     DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
736     DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
737 
738 #undef DISPATCH
739 
740   default:
741     break;
742   }
743   return failure();
744 }
745 
746 //===----------------------------------------------------------------------===//
747 // CmpFOpNanKernelPattern
748 //===----------------------------------------------------------------------===//
749 
750 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
751     arith::CmpFOp op, OpAdaptor adaptor,
752     ConversionPatternRewriter &rewriter) const {
753   if (op.getPredicate() == arith::CmpFPredicate::ORD) {
754     rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
755                                                   adaptor.getRhs());
756     return success();
757   }
758 
759   if (op.getPredicate() == arith::CmpFPredicate::UNO) {
760     rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
761                                                     adaptor.getRhs());
762     return success();
763   }
764 
765   return failure();
766 }
767 
768 //===----------------------------------------------------------------------===//
769 // CmpFOpNanNonePattern
770 //===----------------------------------------------------------------------===//
771 
772 LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
773     arith::CmpFOp op, OpAdaptor adaptor,
774     ConversionPatternRewriter &rewriter) const {
775   if (op.getPredicate() != arith::CmpFPredicate::ORD &&
776       op.getPredicate() != arith::CmpFPredicate::UNO)
777     return failure();
778 
779   Location loc = op.getLoc();
780 
781   Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
782   Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
783 
784   Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
785   if (op.getPredicate() == arith::CmpFPredicate::ORD)
786     replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
787 
788   rewriter.replaceOp(op, replace);
789   return success();
790 }
791 
792 //===----------------------------------------------------------------------===//
793 // SelectOpPattern
794 //===----------------------------------------------------------------------===//
795 
796 LogicalResult
797 SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
798                                  ConversionPatternRewriter &rewriter) const {
799   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
800                                                adaptor.getTrueValue(),
801                                                adaptor.getFalseValue());
802   return success();
803 }
804 
805 //===----------------------------------------------------------------------===//
806 // Pattern Population
807 //===----------------------------------------------------------------------===//
808 
/// Adds to `patterns` the patterns that lower Arithmetic dialect ops to SPIR-V
/// dialect ops. `typeConverter` is passed to every pattern to convert operand
/// and result types.
///
/// NOTE(review): several arith ops have more than one candidate pattern below,
/// for example RemSIOpGLSLPattern/RemSIOpOCLPattern,
/// XOrIOpLogicalPattern/XOrIOpBooleanPattern and
/// CmpIOpBooleanPattern/CmpIOpPattern. Presumably each variant rejects the
/// cases its sibling handles. Patterns with equal benefit may also be tried in
/// registration order, so keep this order stable unless that is verified.
void mlir::arith::populateArithmeticToSPIRVPatterns(
    SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
  // No explicit benefit is passed here, so these all get the default benefit.
  // clang-format off
  patterns.add<
    ConstantCompositeOpPattern,
    ConstantScalarOpPattern,
    spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
    spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
    spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
    spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
    spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
    spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
    RemSIOpGLSLPattern, RemSIOpOCLPattern,
    BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
    BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
    XOrIOpLogicalPattern, XOrIOpBooleanPattern,
    spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
    spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
    spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
    spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
    spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
    spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
    spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
    spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
    spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
    // Casts. Some ops have an extra pattern whose name suggests it handles
    // i1 (boolean) operands separately (ExtUII1Pattern, TruncII1Pattern,
    // UIToFPI1Pattern).
    TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
    TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
    TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
    TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
    TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
    TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
    TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
    TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
    TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
    TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
    CmpIOpBooleanPattern, CmpIOpPattern,
    // CmpFOpNanNonePattern handles only ORD/UNO (via spv.IsNan);
    // CmpFOpPattern handles the remaining ordered/unordered predicates.
    CmpFOpNanNonePattern, CmpFOpPattern,
    SelectOpPattern
  >(typeConverter, patterns.getContext());
  // clang-format on

  // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
  // capability is available. If it does not apply, ORD/UNO comparisons can
  // still be lowered by CmpFOpNanNonePattern registered above.
  patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
                                       /*benefit=*/2);
}
855 
856 //===----------------------------------------------------------------------===//
857 // Pass Definition
858 //===----------------------------------------------------------------------===//
859 
860 namespace {
861 struct ConvertArithmeticToSPIRVPass
862     : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
863   void runOnOperation() override {
864     auto module = getOperation()->getParentOfType<ModuleOp>();
865     auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
866     auto target = SPIRVConversionTarget::get(targetAttr);
867 
868     SPIRVTypeConverter::Options options;
869     options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
870     SPIRVTypeConverter typeConverter(targetAttr, options);
871 
872     RewritePatternSet patterns(&getContext());
873     mlir::arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
874 
875     if (failed(applyPartialConversion(getOperation(), *target,
876                                       std::move(patterns))))
877       signalPassFailure();
878   }
879 };
880 } // namespace
881 
882 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
883   return std::make_unique<ConvertArithmeticToSPIRVPass>();
884 }
885