1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
10 #include "../PassDetail.h"
11 #include "../SPIRVCommon/Pattern.h"
12 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17 #include "llvm/Support/Debug.h"
18 
19 #define DEBUG_TYPE "arith-to-spirv-pattern"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Operation Conversion
25 //===----------------------------------------------------------------------===//
26 
27 namespace {
28 
29 /// Converts composite arith.constant operation to spv.Constant.
30 struct ConstantCompositeOpPattern final
31     : public OpConversionPattern<arith::ConstantOp> {
32   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
33 
34   LogicalResult
35   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
36                   ConversionPatternRewriter &rewriter) const override;
37 };
38 
39 /// Converts scalar arith.constant operation to spv.Constant.
40 struct ConstantScalarOpPattern final
41     : public OpConversionPattern<arith::ConstantOp> {
42   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
43 
44   LogicalResult
45   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
46                   ConversionPatternRewriter &rewriter) const override;
47 };
48 
49 /// Converts arith.remsi to GLSL SPIR-V ops.
50 ///
51 /// This cannot be merged into the template unary/binary pattern due to Vulkan
52 /// restrictions over spv.SRem and spv.SMod.
53 struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> {
54   using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
55 
56   LogicalResult
57   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
58                   ConversionPatternRewriter &rewriter) const override;
59 };
60 
61 /// Converts arith.remsi to OpenCL SPIR-V ops.
62 struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> {
63   using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
64 
65   LogicalResult
66   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
67                   ConversionPatternRewriter &rewriter) const override;
68 };
69 
70 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
71 /// other than the BinaryOpPatternPattern because if the operands are boolean
72 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
73 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
74 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
75 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
76   using OpConversionPattern<Op>::OpConversionPattern;
77 
78   LogicalResult
79   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
80                   ConversionPatternRewriter &rewriter) const override;
81 };
82 
83 /// Converts arith.xori to SPIR-V operations.
84 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
85   using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
86 
87   LogicalResult
88   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
89                   ConversionPatternRewriter &rewriter) const override;
90 };
91 
92 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
93 /// vector of i1.
94 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
95   using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
96 
97   LogicalResult
98   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
99                   ConversionPatternRewriter &rewriter) const override;
100 };
101 
102 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of
103 /// i1.
104 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
105   using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern;
106 
107   LogicalResult
108   matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
109                   ConversionPatternRewriter &rewriter) const override;
110 };
111 
112 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of
113 /// i1.
114 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
115   using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;
116 
117   LogicalResult
118   matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
119                   ConversionPatternRewriter &rewriter) const override;
120 };
121 
122 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of
123 /// i1.
124 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
125   using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;
126 
127   LogicalResult
128   matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
129                   ConversionPatternRewriter &rewriter) const override;
130 };
131 
132 /// Converts type-casting standard operations to SPIR-V operations.
133 template <typename Op, typename SPIRVOp>
134 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
135   using OpConversionPattern<Op>::OpConversionPattern;
136 
137   LogicalResult
138   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
139                   ConversionPatternRewriter &rewriter) const override;
140 };
141 
142 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
143 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
144 public:
145   using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
146 
147   LogicalResult
148   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
149                   ConversionPatternRewriter &rewriter) const override;
150 };
151 
152 /// Converts integer compare operation to SPIR-V ops.
153 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
154 public:
155   using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
156 
157   LogicalResult
158   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
159                   ConversionPatternRewriter &rewriter) const override;
160 };
161 
162 /// Converts floating-point comparison operations to SPIR-V ops.
163 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
164 public:
165   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
166 
167   LogicalResult
168   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
169                   ConversionPatternRewriter &rewriter) const override;
170 };
171 
172 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
173 /// Kernel capability.
174 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
175 public:
176   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
177 
178   LogicalResult
179   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
180                   ConversionPatternRewriter &rewriter) const override;
181 };
182 
183 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
184 /// require additional capability.
185 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
186 public:
187   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
188 
189   LogicalResult
190   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
191                   ConversionPatternRewriter &rewriter) const override;
192 };
193 
194 /// Converts arith.select to spv.Select.
195 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
196 public:
197   using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
198   LogicalResult
199   matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
200                   ConversionPatternRewriter &rewriter) const override;
201 };
202 
203 } // namespace
204 
205 //===----------------------------------------------------------------------===//
206 // Conversion Helpers
207 //===----------------------------------------------------------------------===//
208 
209 /// Converts the given `srcAttr` into a boolean attribute if it holds an
210 /// integral value. Returns null attribute if conversion fails.
211 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
212   if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
213     return boolAttr;
214   if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
215     return builder.getBoolAttr(intAttr.getValue().getBoolValue());
216   return BoolAttr();
217 }
218 
219 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
220 /// Returns null attribute if conversion fails.
221 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
222                                       Builder builder) {
223   // If the source number uses less active bits than the target bitwidth, then
224   // it should be safe to convert.
225   if (srcAttr.getValue().isIntN(dstType.getWidth()))
226     return builder.getIntegerAttr(dstType, srcAttr.getInt());
227 
228   // XXX: Try again by interpreting the source number as a signed value.
229   // Although integers in the standard dialect are signless, they can represent
230   // a signed number. It's the operation decides how to interpret. This is
231   // dangerous, but it seems there is no good way of handling this if we still
232   // want to change the bitwidth. Emit a message at least.
233   if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
234     auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
235     LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
236                             << dstAttr << "' for type '" << dstType << "'\n");
237     return dstAttr;
238   }
239 
240   LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
241                           << "' illegal: cannot fit into target type '"
242                           << dstType << "'\n");
243   return IntegerAttr();
244 }
245 
246 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
247 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
248 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
249                                   Builder builder) {
250   // Only support converting to float for now.
251   if (!dstType.isF32())
252     return FloatAttr();
253 
254   // Try to convert the source floating-point number to single precision.
255   APFloat dstVal = srcAttr.getValue();
256   bool losesInfo = false;
257   APFloat::opStatus status =
258       dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
259   if (status != APFloat::opOK || losesInfo) {
260     LLVM_DEBUG(llvm::dbgs()
261                << srcAttr << " illegal: cannot fit into converted type '"
262                << dstType << "'\n");
263     return FloatAttr();
264   }
265 
266   return builder.getF32FloatAttr(dstVal.convertToFloat());
267 }
268 
269 /// Returns true if the given `type` is a boolean scalar or vector type.
270 static bool isBoolScalarOrVector(Type type) {
271   if (type.isInteger(1))
272     return true;
273   if (auto vecType = type.dyn_cast<VectorType>())
274     return vecType.getElementType().isInteger(1);
275   return false;
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // ConstantOp with composite type
280 //===----------------------------------------------------------------------===//
281 
282 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
283     arith::ConstantOp constOp, OpAdaptor adaptor,
284     ConversionPatternRewriter &rewriter) const {
285   auto srcType = constOp.getType().dyn_cast<ShapedType>();
286   if (!srcType || srcType.getNumElements() == 1)
287     return failure();
288 
289   // arith.constant should only have vector or tenor types.
290   assert((srcType.isa<VectorType, RankedTensorType>()));
291 
292   auto dstType = getTypeConverter()->convertType(srcType);
293   if (!dstType)
294     return failure();
295 
296   auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>();
297   ShapedType dstAttrType = dstElementsAttr.getType();
298   if (!dstElementsAttr)
299     return failure();
300 
301   // If the composite type has more than one dimensions, perform linearization.
302   if (srcType.getRank() > 1) {
303     if (srcType.isa<RankedTensorType>()) {
304       dstAttrType = RankedTensorType::get(srcType.getNumElements(),
305                                           srcType.getElementType());
306       dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
307     } else {
308       // TODO: add support for large vectors.
309       return failure();
310     }
311   }
312 
313   Type srcElemType = srcType.getElementType();
314   Type dstElemType;
315   // Tensor types are converted to SPIR-V array types; vector types are
316   // converted to SPIR-V vector/array types.
317   if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
318     dstElemType = arrayType.getElementType();
319   else
320     dstElemType = dstType.cast<VectorType>().getElementType();
321 
322   // If the source and destination element types are different, perform
323   // attribute conversion.
324   if (srcElemType != dstElemType) {
325     SmallVector<Attribute, 8> elements;
326     if (srcElemType.isa<FloatType>()) {
327       for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
328         FloatAttr dstAttr =
329             convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
330         if (!dstAttr)
331           return failure();
332         elements.push_back(dstAttr);
333       }
334     } else if (srcElemType.isInteger(1)) {
335       return failure();
336     } else {
337       for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
338         IntegerAttr dstAttr = convertIntegerAttr(
339             srcAttr, dstElemType.cast<IntegerType>(), rewriter);
340         if (!dstAttr)
341           return failure();
342         elements.push_back(dstAttr);
343       }
344     }
345 
346     // Unfortunately, we cannot use dialect-specific types for element
347     // attributes; element attributes only works with builtin types. So we need
348     // to prepare another converted builtin types for the destination elements
349     // attribute.
350     if (dstAttrType.isa<RankedTensorType>())
351       dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
352     else
353       dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
354 
355     dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
356   }
357 
358   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
359                                                  dstElementsAttr);
360   return success();
361 }
362 
363 //===----------------------------------------------------------------------===//
364 // ConstantOp with scalar type
365 //===----------------------------------------------------------------------===//
366 
367 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
368     arith::ConstantOp constOp, OpAdaptor adaptor,
369     ConversionPatternRewriter &rewriter) const {
370   Type srcType = constOp.getType();
371   if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
372     if (shapedType.getNumElements() != 1)
373       return failure();
374     srcType = shapedType.getElementType();
375   }
376   if (!srcType.isIntOrIndexOrFloat())
377     return failure();
378 
379   Attribute cstAttr = constOp.getValue();
380   if (cstAttr.getType().isa<ShapedType>())
381     cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>();
382 
383   Type dstType = getTypeConverter()->convertType(srcType);
384   if (!dstType)
385     return failure();
386 
387   // Floating-point types.
388   if (srcType.isa<FloatType>()) {
389     auto srcAttr = cstAttr.cast<FloatAttr>();
390     auto dstAttr = srcAttr;
391 
392     // Floating-point types not supported in the target environment are all
393     // converted to float type.
394     if (srcType != dstType) {
395       dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
396       if (!dstAttr)
397         return failure();
398     }
399 
400     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
401     return success();
402   }
403 
404   // Bool type.
405   if (srcType.isInteger(1)) {
406     // arith.constant can use 0/1 instead of true/false for i1 values. We need
407     // to handle that here.
408     auto dstAttr = convertBoolAttr(cstAttr, rewriter);
409     if (!dstAttr)
410       return failure();
411     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
412     return success();
413   }
414 
415   // IndexType or IntegerType. Index values are converted to 32-bit integer
416   // values when converting to SPIR-V.
417   auto srcAttr = cstAttr.cast<IntegerAttr>();
418   auto dstAttr =
419       convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
420   if (!dstAttr)
421     return failure();
422   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
423   return success();
424 }
425 
426 //===----------------------------------------------------------------------===//
427 // RemSIOpGLSLPattern
428 //===----------------------------------------------------------------------===//
429 
430 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
431 /// the sign of `signOperand`.
432 ///
433 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
434 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
435 /// the result is undefined."  So we cannot directly use spv.SRem/spv.SMod
436 /// if either operand can be negative. Emulate it via spv.UMod.
437 template <typename SignedAbsOp>
438 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
439                                     Value signOperand, OpBuilder &builder) {
440   assert(lhs.getType() == rhs.getType());
441   assert(lhs == signOperand || rhs == signOperand);
442 
443   Type type = lhs.getType();
444 
445   // Calculate the remainder with spv.UMod.
446   Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
447   Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
448   Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
449 
450   // Fix the sign.
451   Value isPositive;
452   if (lhs == signOperand)
453     isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
454   else
455     isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
456   Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
457   return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
458 }
459 
460 LogicalResult
461 RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
462                                     ConversionPatternRewriter &rewriter) const {
463   Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>(
464       op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
465       adaptor.getOperands()[0], rewriter);
466   rewriter.replaceOp(op, result);
467 
468   return success();
469 }
470 
471 //===----------------------------------------------------------------------===//
472 // RemSIOpOCLPattern
473 //===----------------------------------------------------------------------===//
474 
475 LogicalResult
476 RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
477                                    ConversionPatternRewriter &rewriter) const {
478   Value result = emulateSignedRemainder<spirv::OCLSAbsOp>(
479       op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
480       adaptor.getOperands()[0], rewriter);
481   rewriter.replaceOp(op, result);
482 
483   return success();
484 }
485 
486 //===----------------------------------------------------------------------===//
487 // BitwiseOpPattern
488 //===----------------------------------------------------------------------===//
489 
490 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
491 LogicalResult
492 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
493     Op op, typename Op::Adaptor adaptor,
494     ConversionPatternRewriter &rewriter) const {
495   assert(adaptor.getOperands().size() == 2);
496   auto dstType =
497       this->getTypeConverter()->convertType(op.getResult().getType());
498   if (!dstType)
499     return failure();
500   if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
501     rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
502                                                          adaptor.getOperands());
503   } else {
504     rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
505                                                          adaptor.getOperands());
506   }
507   return success();
508 }
509 
510 //===----------------------------------------------------------------------===//
511 // XOrIOpLogicalPattern
512 //===----------------------------------------------------------------------===//
513 
514 LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
515     arith::XOrIOp op, OpAdaptor adaptor,
516     ConversionPatternRewriter &rewriter) const {
517   assert(adaptor.getOperands().size() == 2);
518 
519   if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
520     return failure();
521 
522   auto dstType = getTypeConverter()->convertType(op.getType());
523   if (!dstType)
524     return failure();
525   rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
526                                                    adaptor.getOperands());
527 
528   return success();
529 }
530 
531 //===----------------------------------------------------------------------===//
532 // XOrIOpBooleanPattern
533 //===----------------------------------------------------------------------===//
534 
535 LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
536     arith::XOrIOp op, OpAdaptor adaptor,
537     ConversionPatternRewriter &rewriter) const {
538   assert(adaptor.getOperands().size() == 2);
539 
540   if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
541     return failure();
542 
543   auto dstType = getTypeConverter()->convertType(op.getType());
544   if (!dstType)
545     return failure();
546   rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
547                                                         adaptor.getOperands());
548   return success();
549 }
550 
551 //===----------------------------------------------------------------------===//
552 // UIToFPI1Pattern
553 //===----------------------------------------------------------------------===//
554 
555 LogicalResult
556 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
557                                  ConversionPatternRewriter &rewriter) const {
558   auto srcType = adaptor.getOperands().front().getType();
559   if (!isBoolScalarOrVector(srcType))
560     return failure();
561 
562   auto dstType =
563       this->getTypeConverter()->convertType(op.getResult().getType());
564   Location loc = op.getLoc();
565   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
566   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
567   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
568       op, dstType, adaptor.getOperands().front(), one, zero);
569   return success();
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // ExtUII1Pattern
574 //===----------------------------------------------------------------------===//
575 
576 LogicalResult
577 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
578                                 ConversionPatternRewriter &rewriter) const {
579   auto srcType = adaptor.getOperands().front().getType();
580   if (!isBoolScalarOrVector(srcType))
581     return failure();
582 
583   auto dstType =
584       this->getTypeConverter()->convertType(op.getResult().getType());
585   Location loc = op.getLoc();
586   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
587   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
588   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
589       op, dstType, adaptor.getOperands().front(), one, zero);
590   return success();
591 }
592 
593 //===----------------------------------------------------------------------===//
594 // TruncII1Pattern
595 //===----------------------------------------------------------------------===//
596 
597 LogicalResult
598 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
599                                  ConversionPatternRewriter &rewriter) const {
600   auto dstType =
601       this->getTypeConverter()->convertType(op.getResult().getType());
602   if (!isBoolScalarOrVector(dstType))
603     return failure();
604 
605   Location loc = op.getLoc();
606   auto srcType = adaptor.getOperands().front().getType();
607   // Check if (x & 1) == 1.
608   Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
609   Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
610       loc, srcType, adaptor.getOperands()[0], mask);
611   Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
612 
613   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
614   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
615   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
616   return success();
617 }
618 
619 //===----------------------------------------------------------------------===//
620 // TypeCastingOpPattern
621 //===----------------------------------------------------------------------===//
622 
623 template <typename Op, typename SPIRVOp>
624 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
625     Op op, typename Op::Adaptor adaptor,
626     ConversionPatternRewriter &rewriter) const {
627   assert(adaptor.getOperands().size() == 1);
628   auto srcType = adaptor.getOperands().front().getType();
629   auto dstType =
630       this->getTypeConverter()->convertType(op.getResult().getType());
631   if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
632     return failure();
633   if (dstType == srcType) {
634     // Due to type conversion, we are seeing the same source and target type.
635     // Then we can just erase this operation by forwarding its operand.
636     rewriter.replaceOp(op, adaptor.getOperands().front());
637   } else {
638     rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
639                                                   adaptor.getOperands());
640   }
641   return success();
642 }
643 
644 //===----------------------------------------------------------------------===//
645 // CmpIOpBooleanPattern
646 //===----------------------------------------------------------------------===//
647 
648 LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
649     arith::CmpIOp op, OpAdaptor adaptor,
650     ConversionPatternRewriter &rewriter) const {
651   Type operandType = op.getLhs().getType();
652   if (!isBoolScalarOrVector(operandType))
653     return failure();
654 
655   switch (op.getPredicate()) {
656 #define DISPATCH(cmpPredicate, spirvOp)                                        \
657   case cmpPredicate:                                                           \
658     rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
659                                          adaptor.getLhs(), adaptor.getRhs());  \
660     return success();
661 
662     DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp);
663     DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp);
664 
665 #undef DISPATCH
666   default:;
667   }
668   return failure();
669 }
670 
671 //===----------------------------------------------------------------------===//
672 // CmpIOpPattern
673 //===----------------------------------------------------------------------===//
674 
/// Lowers arith.cmpi on non-boolean operands to the matching SPIR-V integer
/// comparison op.
///
/// For SPIR-V ops carrying the UnsignedOp trait, the conversion is refused
/// with an error if the type converter changes the operand type, since the
/// required bitwidth emulation is not implemented.
LogicalResult
CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  Type operandType = op.getLhs().getType();
  // Boolean (i1 / vector of i1) comparisons are not handled by this pattern.
  if (isBoolScalarOrVector(operandType))
    return failure();

  // Each DISPATCH(pred, spirvOp) expands to a `case` that replaces `op` with
  // `spirvOp` on the converted operands, keeping the original result type.
  switch (op.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
        operandType != this->getTypeConverter()->convertType(operandType)) {   \
      return op.emitError(                                                     \
          "bitwidth emulation is not implemented yet on unsigned op");         \
    }                                                                          \
    rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
                                         adaptor.getLhs(), adaptor.getRhs());  \
    return success();

    DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
    DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
    DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
    DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
    DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
    DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
    DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
    DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
    DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
    DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);

#undef DISPATCH
  }
  // Reached only for a predicate not listed above.
  return failure();
}
709 
710 //===----------------------------------------------------------------------===//
711 // CmpFOpPattern
712 //===----------------------------------------------------------------------===//
713 
/// Lowers arith.cmpf with an ordered (O*) or unordered (U*) comparison
/// predicate to the matching spv.FOrd* / spv.FUnord* op. Predicates not listed
/// below (e.g. ORD and UNO) fall to `default` and fail to match here.
LogicalResult
CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
                               ConversionPatternRewriter &rewriter) const {
  // Each DISPATCH(pred, spirvOp) expands to a `case` that replaces `op` with
  // `spirvOp` on the converted operands, keeping the original result type.
  switch (op.getPredicate()) {
#define DISPATCH(cmpPredicate, spirvOp)                                        \
  case cmpPredicate:                                                           \
    rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
                                         adaptor.getLhs(), adaptor.getRhs());  \
    return success();

    // Ordered.
    DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
    DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
    DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
    DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
    DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
    DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
    // Unordered.
    DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
    DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
    DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
    DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
    DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
    DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);

#undef DISPATCH

  default:
    break;
  }
  return failure();
}
746 
747 //===----------------------------------------------------------------------===//
748 // CmpFOpNanKernelPattern
749 //===----------------------------------------------------------------------===//
750 
751 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
752     arith::CmpFOp op, OpAdaptor adaptor,
753     ConversionPatternRewriter &rewriter) const {
754   if (op.getPredicate() == arith::CmpFPredicate::ORD) {
755     rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
756                                                   adaptor.getRhs());
757     return success();
758   }
759 
760   if (op.getPredicate() == arith::CmpFPredicate::UNO) {
761     rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
762                                                     adaptor.getRhs());
763     return success();
764   }
765 
766   return failure();
767 }
768 
769 //===----------------------------------------------------------------------===//
770 // CmpFOpNanNonePattern
771 //===----------------------------------------------------------------------===//
772 
773 LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
774     arith::CmpFOp op, OpAdaptor adaptor,
775     ConversionPatternRewriter &rewriter) const {
776   if (op.getPredicate() != arith::CmpFPredicate::ORD &&
777       op.getPredicate() != arith::CmpFPredicate::UNO)
778     return failure();
779 
780   Location loc = op.getLoc();
781 
782   Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
783   Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
784 
785   Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
786   if (op.getPredicate() == arith::CmpFPredicate::ORD)
787     replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
788 
789   rewriter.replaceOp(op, replace);
790   return success();
791 }
792 
793 //===----------------------------------------------------------------------===//
794 // SelectOpPattern
795 //===----------------------------------------------------------------------===//
796 
797 LogicalResult
798 SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
799                                  ConversionPatternRewriter &rewriter) const {
800   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
801                                                adaptor.getTrueValue(),
802                                                adaptor.getFalseValue());
803   return success();
804 }
805 
806 //===----------------------------------------------------------------------===//
807 // Pattern Population
808 //===----------------------------------------------------------------------===//
809 
810 void mlir::arith::populateArithmeticToSPIRVPatterns(
811     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
812   // clang-format off
813   patterns.add<
814     ConstantCompositeOpPattern,
815     ConstantScalarOpPattern,
816     spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
817     spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
818     spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
819     spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
820     spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
821     spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
822     RemSIOpGLSLPattern, RemSIOpOCLPattern,
823     BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
824     BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
825     XOrIOpLogicalPattern, XOrIOpBooleanPattern,
826     spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
827     spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
828     spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
829     spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
830     spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
831     spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
832     spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
833     spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
834     spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
835     TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
836     TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
837     TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
838     TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
839     TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
840     TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
841     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
842     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
843     TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
844     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
845     CmpIOpBooleanPattern, CmpIOpPattern,
846     CmpFOpNanNonePattern, CmpFOpPattern,
847     SelectOpPattern,
848 
849     spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>,
850     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>,
851     spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>,
852     spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLSLFMinOp>,
853     spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
854     spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>
855   >(typeConverter, patterns.getContext());
856   // clang-format on
857 
858   // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
859   // capability is available.
860   patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
861                                        /*benefit=*/2);
862 }
863 
864 //===----------------------------------------------------------------------===//
865 // Pass Definition
866 //===----------------------------------------------------------------------===//
867 
868 namespace {
869 struct ConvertArithmeticToSPIRVPass
870     : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
871   void runOnOperation() override {
872     auto module = getOperation();
873     auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
874     auto target = SPIRVConversionTarget::get(targetAttr);
875 
876     SPIRVTypeConverter::Options options;
877     options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
878     SPIRVTypeConverter typeConverter(targetAttr, options);
879 
880     RewritePatternSet patterns(&getContext());
881     arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
882     populateFuncToSPIRVPatterns(typeConverter, patterns);
883     populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
884 
885     if (failed(applyPartialConversion(module, *target, std::move(patterns))))
886       signalPassFailure();
887   }
888 };
889 } // namespace
890 
891 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
892   return std::make_unique<ConvertArithmeticToSPIRVPass>();
893 }
894