1 //===- ArithmeticToSPIRV.cpp - Arithmetic to SPIRV dialect conversion -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Conversion/ArithmeticToSPIRV/ArithmeticToSPIRV.h"
10 #include "../PassDetail.h"
11 #include "../SPIRVCommon/Pattern.h"
12 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h"
13 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
15 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
16 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
17 #include "llvm/Support/Debug.h"
18 
19 #define DEBUG_TYPE "arith-to-spirv-pattern"
20 
21 using namespace mlir;
22 
23 //===----------------------------------------------------------------------===//
24 // Operation Conversion
25 //===----------------------------------------------------------------------===//
26 
27 namespace {
28 
29 /// Converts composite arith.constant operation to spv.Constant.
30 struct ConstantCompositeOpPattern final
31     : public OpConversionPattern<arith::ConstantOp> {
32   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
33 
34   LogicalResult
35   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
36                   ConversionPatternRewriter &rewriter) const override;
37 };
38 
39 /// Converts scalar arith.constant operation to spv.Constant.
40 struct ConstantScalarOpPattern final
41     : public OpConversionPattern<arith::ConstantOp> {
42   using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
43 
44   LogicalResult
45   matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
46                   ConversionPatternRewriter &rewriter) const override;
47 };
48 
49 /// Converts arith.remsi to GLSL SPIR-V ops.
50 ///
51 /// This cannot be merged into the template unary/binary pattern due to Vulkan
52 /// restrictions over spv.SRem and spv.SMod.
53 struct RemSIOpGLSLPattern final : public OpConversionPattern<arith::RemSIOp> {
54   using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
55 
56   LogicalResult
57   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
58                   ConversionPatternRewriter &rewriter) const override;
59 };
60 
61 /// Converts arith.remsi to OpenCL SPIR-V ops.
62 struct RemSIOpOCLPattern final : public OpConversionPattern<arith::RemSIOp> {
63   using OpConversionPattern<arith::RemSIOp>::OpConversionPattern;
64 
65   LogicalResult
66   matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
67                   ConversionPatternRewriter &rewriter) const override;
68 };
69 
70 /// Converts bitwise operations to SPIR-V operations. This is a special pattern
71 /// other than the BinaryOpPatternPattern because if the operands are boolean
72 /// values, SPIR-V uses different operations (`SPIRVLogicalOp`). For
73 /// non-boolean operands, SPIR-V should use `SPIRVBitwiseOp`.
74 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
75 struct BitwiseOpPattern final : public OpConversionPattern<Op> {
76   using OpConversionPattern<Op>::OpConversionPattern;
77 
78   LogicalResult
79   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
80                   ConversionPatternRewriter &rewriter) const override;
81 };
82 
83 /// Converts arith.xori to SPIR-V operations.
84 struct XOrIOpLogicalPattern final : public OpConversionPattern<arith::XOrIOp> {
85   using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
86 
87   LogicalResult
88   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
89                   ConversionPatternRewriter &rewriter) const override;
90 };
91 
92 /// Converts arith.xori to SPIR-V operations if the type of source is i1 or
93 /// vector of i1.
94 struct XOrIOpBooleanPattern final : public OpConversionPattern<arith::XOrIOp> {
95   using OpConversionPattern<arith::XOrIOp>::OpConversionPattern;
96 
97   LogicalResult
98   matchAndRewrite(arith::XOrIOp op, OpAdaptor adaptor,
99                   ConversionPatternRewriter &rewriter) const override;
100 };
101 
102 /// Converts arith.uitofp to spv.Select if the type of source is i1 or vector of
103 /// i1.
104 struct UIToFPI1Pattern final : public OpConversionPattern<arith::UIToFPOp> {
105   using OpConversionPattern<arith::UIToFPOp>::OpConversionPattern;
106 
107   LogicalResult
108   matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
109                   ConversionPatternRewriter &rewriter) const override;
110 };
111 
112 /// Converts arith.extui to spv.Select if the type of source is i1 or vector of
113 /// i1.
114 struct ExtUII1Pattern final : public OpConversionPattern<arith::ExtUIOp> {
115   using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;
116 
117   LogicalResult
118   matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
119                   ConversionPatternRewriter &rewriter) const override;
120 };
121 
122 /// Converts arith.trunci to spv.Select if the type of result is i1 or vector of
123 /// i1.
124 struct TruncII1Pattern final : public OpConversionPattern<arith::TruncIOp> {
125   using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;
126 
127   LogicalResult
128   matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
129                   ConversionPatternRewriter &rewriter) const override;
130 };
131 
132 /// Converts type-casting standard operations to SPIR-V operations.
133 template <typename Op, typename SPIRVOp>
134 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
135   using OpConversionPattern<Op>::OpConversionPattern;
136 
137   LogicalResult
138   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
139                   ConversionPatternRewriter &rewriter) const override;
140 };
141 
142 /// Converts integer compare operation on i1 type operands to SPIR-V ops.
143 class CmpIOpBooleanPattern final : public OpConversionPattern<arith::CmpIOp> {
144 public:
145   using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
146 
147   LogicalResult
148   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
149                   ConversionPatternRewriter &rewriter) const override;
150 };
151 
152 /// Converts integer compare operation to SPIR-V ops.
153 class CmpIOpPattern final : public OpConversionPattern<arith::CmpIOp> {
154 public:
155   using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
156 
157   LogicalResult
158   matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
159                   ConversionPatternRewriter &rewriter) const override;
160 };
161 
162 /// Converts floating-point comparison operations to SPIR-V ops.
163 class CmpFOpPattern final : public OpConversionPattern<arith::CmpFOp> {
164 public:
165   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
166 
167   LogicalResult
168   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
169                   ConversionPatternRewriter &rewriter) const override;
170 };
171 
172 /// Converts floating point NaN check to SPIR-V ops. This pattern requires
173 /// Kernel capability.
174 class CmpFOpNanKernelPattern final : public OpConversionPattern<arith::CmpFOp> {
175 public:
176   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
177 
178   LogicalResult
179   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
180                   ConversionPatternRewriter &rewriter) const override;
181 };
182 
183 /// Converts floating point NaN check to SPIR-V ops. This pattern does not
184 /// require additional capability.
185 class CmpFOpNanNonePattern final : public OpConversionPattern<arith::CmpFOp> {
186 public:
187   using OpConversionPattern<arith::CmpFOp>::OpConversionPattern;
188 
189   LogicalResult
190   matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
191                   ConversionPatternRewriter &rewriter) const override;
192 };
193 
194 /// Converts arith.select to spv.Select.
195 class SelectOpPattern final : public OpConversionPattern<arith::SelectOp> {
196 public:
197   using OpConversionPattern<arith::SelectOp>::OpConversionPattern;
198   LogicalResult
199   matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
200                   ConversionPatternRewriter &rewriter) const override;
201 };
202 
203 } // namespace
204 
205 //===----------------------------------------------------------------------===//
206 // Conversion Helpers
207 //===----------------------------------------------------------------------===//
208 
209 /// Converts the given `srcAttr` into a boolean attribute if it holds an
210 /// integral value. Returns null attribute if conversion fails.
211 static BoolAttr convertBoolAttr(Attribute srcAttr, Builder builder) {
212   if (auto boolAttr = srcAttr.dyn_cast<BoolAttr>())
213     return boolAttr;
214   if (auto intAttr = srcAttr.dyn_cast<IntegerAttr>())
215     return builder.getBoolAttr(intAttr.getValue().getBoolValue());
216   return BoolAttr();
217 }
218 
219 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
220 /// Returns null attribute if conversion fails.
221 static IntegerAttr convertIntegerAttr(IntegerAttr srcAttr, IntegerType dstType,
222                                       Builder builder) {
223   // If the source number uses less active bits than the target bitwidth, then
224   // it should be safe to convert.
225   if (srcAttr.getValue().isIntN(dstType.getWidth()))
226     return builder.getIntegerAttr(dstType, srcAttr.getInt());
227 
228   // XXX: Try again by interpreting the source number as a signed value.
229   // Although integers in the standard dialect are signless, they can represent
230   // a signed number. It's the operation decides how to interpret. This is
231   // dangerous, but it seems there is no good way of handling this if we still
232   // want to change the bitwidth. Emit a message at least.
233   if (srcAttr.getValue().isSignedIntN(dstType.getWidth())) {
234     auto dstAttr = builder.getIntegerAttr(dstType, srcAttr.getInt());
235     LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr << "' converted to '"
236                             << dstAttr << "' for type '" << dstType << "'\n");
237     return dstAttr;
238   }
239 
240   LLVM_DEBUG(llvm::dbgs() << "attribute '" << srcAttr
241                           << "' illegal: cannot fit into target type '"
242                           << dstType << "'\n");
243   return IntegerAttr();
244 }
245 
246 /// Converts the given `srcAttr` to a new attribute of the given `dstType`.
247 /// Returns null attribute if `dstType` is not 32-bit or conversion fails.
248 static FloatAttr convertFloatAttr(FloatAttr srcAttr, FloatType dstType,
249                                   Builder builder) {
250   // Only support converting to float for now.
251   if (!dstType.isF32())
252     return FloatAttr();
253 
254   // Try to convert the source floating-point number to single precision.
255   APFloat dstVal = srcAttr.getValue();
256   bool losesInfo = false;
257   APFloat::opStatus status =
258       dstVal.convert(APFloat::IEEEsingle(), APFloat::rmTowardZero, &losesInfo);
259   if (status != APFloat::opOK || losesInfo) {
260     LLVM_DEBUG(llvm::dbgs()
261                << srcAttr << " illegal: cannot fit into converted type '"
262                << dstType << "'\n");
263     return FloatAttr();
264   }
265 
266   return builder.getF32FloatAttr(dstVal.convertToFloat());
267 }
268 
269 /// Returns true if the given `type` is a boolean scalar or vector type.
270 static bool isBoolScalarOrVector(Type type) {
271   if (type.isInteger(1))
272     return true;
273   if (auto vecType = type.dyn_cast<VectorType>())
274     return vecType.getElementType().isInteger(1);
275   return false;
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // ConstantOp with composite type
280 //===----------------------------------------------------------------------===//
281 
282 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
283     arith::ConstantOp constOp, OpAdaptor adaptor,
284     ConversionPatternRewriter &rewriter) const {
285   auto srcType = constOp.getType().dyn_cast<ShapedType>();
286   if (!srcType || srcType.getNumElements() == 1)
287     return failure();
288 
289   // arith.constant should only have vector or tenor types.
290   assert((srcType.isa<VectorType, RankedTensorType>()));
291 
292   auto dstType = getTypeConverter()->convertType(srcType);
293   if (!dstType)
294     return failure();
295 
296   auto dstElementsAttr = constOp.getValue().dyn_cast<DenseElementsAttr>();
297   if (!dstElementsAttr)
298     return failure();
299 
300   ShapedType dstAttrType = dstElementsAttr.getType();
301 
302   // If the composite type has more than one dimensions, perform linearization.
303   if (srcType.getRank() > 1) {
304     if (srcType.isa<RankedTensorType>()) {
305       dstAttrType = RankedTensorType::get(srcType.getNumElements(),
306                                           srcType.getElementType());
307       dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
308     } else {
309       // TODO: add support for large vectors.
310       return failure();
311     }
312   }
313 
314   Type srcElemType = srcType.getElementType();
315   Type dstElemType;
316   // Tensor types are converted to SPIR-V array types; vector types are
317   // converted to SPIR-V vector/array types.
318   if (auto arrayType = dstType.dyn_cast<spirv::ArrayType>())
319     dstElemType = arrayType.getElementType();
320   else
321     dstElemType = dstType.cast<VectorType>().getElementType();
322 
323   // If the source and destination element types are different, perform
324   // attribute conversion.
325   if (srcElemType != dstElemType) {
326     SmallVector<Attribute, 8> elements;
327     if (srcElemType.isa<FloatType>()) {
328       for (FloatAttr srcAttr : dstElementsAttr.getValues<FloatAttr>()) {
329         FloatAttr dstAttr =
330             convertFloatAttr(srcAttr, dstElemType.cast<FloatType>(), rewriter);
331         if (!dstAttr)
332           return failure();
333         elements.push_back(dstAttr);
334       }
335     } else if (srcElemType.isInteger(1)) {
336       return failure();
337     } else {
338       for (IntegerAttr srcAttr : dstElementsAttr.getValues<IntegerAttr>()) {
339         IntegerAttr dstAttr = convertIntegerAttr(
340             srcAttr, dstElemType.cast<IntegerType>(), rewriter);
341         if (!dstAttr)
342           return failure();
343         elements.push_back(dstAttr);
344       }
345     }
346 
347     // Unfortunately, we cannot use dialect-specific types for element
348     // attributes; element attributes only works with builtin types. So we need
349     // to prepare another converted builtin types for the destination elements
350     // attribute.
351     if (dstAttrType.isa<RankedTensorType>())
352       dstAttrType = RankedTensorType::get(dstAttrType.getShape(), dstElemType);
353     else
354       dstAttrType = VectorType::get(dstAttrType.getShape(), dstElemType);
355 
356     dstElementsAttr = DenseElementsAttr::get(dstAttrType, elements);
357   }
358 
359   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType,
360                                                  dstElementsAttr);
361   return success();
362 }
363 
364 //===----------------------------------------------------------------------===//
365 // ConstantOp with scalar type
366 //===----------------------------------------------------------------------===//
367 
368 LogicalResult ConstantScalarOpPattern::matchAndRewrite(
369     arith::ConstantOp constOp, OpAdaptor adaptor,
370     ConversionPatternRewriter &rewriter) const {
371   Type srcType = constOp.getType();
372   if (auto shapedType = srcType.dyn_cast<ShapedType>()) {
373     if (shapedType.getNumElements() != 1)
374       return failure();
375     srcType = shapedType.getElementType();
376   }
377   if (!srcType.isIntOrIndexOrFloat())
378     return failure();
379 
380   Attribute cstAttr = constOp.getValue();
381   if (cstAttr.getType().isa<ShapedType>())
382     cstAttr = cstAttr.cast<DenseElementsAttr>().getSplatValue<Attribute>();
383 
384   Type dstType = getTypeConverter()->convertType(srcType);
385   if (!dstType)
386     return failure();
387 
388   // Floating-point types.
389   if (srcType.isa<FloatType>()) {
390     auto srcAttr = cstAttr.cast<FloatAttr>();
391     auto dstAttr = srcAttr;
392 
393     // Floating-point types not supported in the target environment are all
394     // converted to float type.
395     if (srcType != dstType) {
396       dstAttr = convertFloatAttr(srcAttr, dstType.cast<FloatType>(), rewriter);
397       if (!dstAttr)
398         return failure();
399     }
400 
401     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
402     return success();
403   }
404 
405   // Bool type.
406   if (srcType.isInteger(1)) {
407     // arith.constant can use 0/1 instead of true/false for i1 values. We need
408     // to handle that here.
409     auto dstAttr = convertBoolAttr(cstAttr, rewriter);
410     if (!dstAttr)
411       return failure();
412     rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
413     return success();
414   }
415 
416   // IndexType or IntegerType. Index values are converted to 32-bit integer
417   // values when converting to SPIR-V.
418   auto srcAttr = cstAttr.cast<IntegerAttr>();
419   auto dstAttr =
420       convertIntegerAttr(srcAttr, dstType.cast<IntegerType>(), rewriter);
421   if (!dstAttr)
422     return failure();
423   rewriter.replaceOpWithNewOp<spirv::ConstantOp>(constOp, dstType, dstAttr);
424   return success();
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // RemSIOpGLSLPattern
429 //===----------------------------------------------------------------------===//
430 
431 /// Returns signed remainder for `lhs` and `rhs` and lets the result follow
432 /// the sign of `signOperand`.
433 ///
434 /// Note that this is needed for Vulkan. Per the Vulkan's SPIR-V environment
435 /// spec, "for the OpSRem and OpSMod instructions, if either operand is negative
436 /// the result is undefined."  So we cannot directly use spv.SRem/spv.SMod
437 /// if either operand can be negative. Emulate it via spv.UMod.
438 template <typename SignedAbsOp>
439 static Value emulateSignedRemainder(Location loc, Value lhs, Value rhs,
440                                     Value signOperand, OpBuilder &builder) {
441   assert(lhs.getType() == rhs.getType());
442   assert(lhs == signOperand || rhs == signOperand);
443 
444   Type type = lhs.getType();
445 
446   // Calculate the remainder with spv.UMod.
447   Value lhsAbs = builder.create<SignedAbsOp>(loc, type, lhs);
448   Value rhsAbs = builder.create<SignedAbsOp>(loc, type, rhs);
449   Value abs = builder.create<spirv::UModOp>(loc, lhsAbs, rhsAbs);
450 
451   // Fix the sign.
452   Value isPositive;
453   if (lhs == signOperand)
454     isPositive = builder.create<spirv::IEqualOp>(loc, lhs, lhsAbs);
455   else
456     isPositive = builder.create<spirv::IEqualOp>(loc, rhs, rhsAbs);
457   Value absNegate = builder.create<spirv::SNegateOp>(loc, type, abs);
458   return builder.create<spirv::SelectOp>(loc, type, isPositive, abs, absNegate);
459 }
460 
461 LogicalResult
462 RemSIOpGLSLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
463                                     ConversionPatternRewriter &rewriter) const {
464   Value result = emulateSignedRemainder<spirv::GLSLSAbsOp>(
465       op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
466       adaptor.getOperands()[0], rewriter);
467   rewriter.replaceOp(op, result);
468 
469   return success();
470 }
471 
472 //===----------------------------------------------------------------------===//
473 // RemSIOpOCLPattern
474 //===----------------------------------------------------------------------===//
475 
476 LogicalResult
477 RemSIOpOCLPattern::matchAndRewrite(arith::RemSIOp op, OpAdaptor adaptor,
478                                    ConversionPatternRewriter &rewriter) const {
479   Value result = emulateSignedRemainder<spirv::OCLSAbsOp>(
480       op.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1],
481       adaptor.getOperands()[0], rewriter);
482   rewriter.replaceOp(op, result);
483 
484   return success();
485 }
486 
487 //===----------------------------------------------------------------------===//
488 // BitwiseOpPattern
489 //===----------------------------------------------------------------------===//
490 
491 template <typename Op, typename SPIRVLogicalOp, typename SPIRVBitwiseOp>
492 LogicalResult
493 BitwiseOpPattern<Op, SPIRVLogicalOp, SPIRVBitwiseOp>::matchAndRewrite(
494     Op op, typename Op::Adaptor adaptor,
495     ConversionPatternRewriter &rewriter) const {
496   assert(adaptor.getOperands().size() == 2);
497   auto dstType =
498       this->getTypeConverter()->convertType(op.getResult().getType());
499   if (!dstType)
500     return failure();
501   if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) {
502     rewriter.template replaceOpWithNewOp<SPIRVLogicalOp>(op, dstType,
503                                                          adaptor.getOperands());
504   } else {
505     rewriter.template replaceOpWithNewOp<SPIRVBitwiseOp>(op, dstType,
506                                                          adaptor.getOperands());
507   }
508   return success();
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // XOrIOpLogicalPattern
513 //===----------------------------------------------------------------------===//
514 
515 LogicalResult XOrIOpLogicalPattern::matchAndRewrite(
516     arith::XOrIOp op, OpAdaptor adaptor,
517     ConversionPatternRewriter &rewriter) const {
518   assert(adaptor.getOperands().size() == 2);
519 
520   if (isBoolScalarOrVector(adaptor.getOperands().front().getType()))
521     return failure();
522 
523   auto dstType = getTypeConverter()->convertType(op.getType());
524   if (!dstType)
525     return failure();
526   rewriter.replaceOpWithNewOp<spirv::BitwiseXorOp>(op, dstType,
527                                                    adaptor.getOperands());
528 
529   return success();
530 }
531 
532 //===----------------------------------------------------------------------===//
533 // XOrIOpBooleanPattern
534 //===----------------------------------------------------------------------===//
535 
536 LogicalResult XOrIOpBooleanPattern::matchAndRewrite(
537     arith::XOrIOp op, OpAdaptor adaptor,
538     ConversionPatternRewriter &rewriter) const {
539   assert(adaptor.getOperands().size() == 2);
540 
541   if (!isBoolScalarOrVector(adaptor.getOperands().front().getType()))
542     return failure();
543 
544   auto dstType = getTypeConverter()->convertType(op.getType());
545   if (!dstType)
546     return failure();
547   rewriter.replaceOpWithNewOp<spirv::LogicalNotEqualOp>(op, dstType,
548                                                         adaptor.getOperands());
549   return success();
550 }
551 
552 //===----------------------------------------------------------------------===//
553 // UIToFPI1Pattern
554 //===----------------------------------------------------------------------===//
555 
556 LogicalResult
557 UIToFPI1Pattern::matchAndRewrite(arith::UIToFPOp op, OpAdaptor adaptor,
558                                  ConversionPatternRewriter &rewriter) const {
559   auto srcType = adaptor.getOperands().front().getType();
560   if (!isBoolScalarOrVector(srcType))
561     return failure();
562 
563   auto dstType =
564       this->getTypeConverter()->convertType(op.getResult().getType());
565   Location loc = op.getLoc();
566   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
567   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
568   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
569       op, dstType, adaptor.getOperands().front(), one, zero);
570   return success();
571 }
572 
573 //===----------------------------------------------------------------------===//
574 // ExtUII1Pattern
575 //===----------------------------------------------------------------------===//
576 
577 LogicalResult
578 ExtUII1Pattern::matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
579                                 ConversionPatternRewriter &rewriter) const {
580   auto srcType = adaptor.getOperands().front().getType();
581   if (!isBoolScalarOrVector(srcType))
582     return failure();
583 
584   auto dstType =
585       this->getTypeConverter()->convertType(op.getResult().getType());
586   Location loc = op.getLoc();
587   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
588   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
589   rewriter.template replaceOpWithNewOp<spirv::SelectOp>(
590       op, dstType, adaptor.getOperands().front(), one, zero);
591   return success();
592 }
593 
594 //===----------------------------------------------------------------------===//
595 // TruncII1Pattern
596 //===----------------------------------------------------------------------===//
597 
598 LogicalResult
599 TruncII1Pattern::matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
600                                  ConversionPatternRewriter &rewriter) const {
601   auto dstType =
602       this->getTypeConverter()->convertType(op.getResult().getType());
603   if (!isBoolScalarOrVector(dstType))
604     return failure();
605 
606   Location loc = op.getLoc();
607   auto srcType = adaptor.getOperands().front().getType();
608   // Check if (x & 1) == 1.
609   Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter);
610   Value maskedSrc = rewriter.create<spirv::BitwiseAndOp>(
611       loc, srcType, adaptor.getOperands()[0], mask);
612   Value isOne = rewriter.create<spirv::IEqualOp>(loc, maskedSrc, mask);
613 
614   Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter);
615   Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter);
616   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, dstType, isOne, one, zero);
617   return success();
618 }
619 
620 //===----------------------------------------------------------------------===//
621 // TypeCastingOpPattern
622 //===----------------------------------------------------------------------===//
623 
624 template <typename Op, typename SPIRVOp>
625 LogicalResult TypeCastingOpPattern<Op, SPIRVOp>::matchAndRewrite(
626     Op op, typename Op::Adaptor adaptor,
627     ConversionPatternRewriter &rewriter) const {
628   assert(adaptor.getOperands().size() == 1);
629   auto srcType = adaptor.getOperands().front().getType();
630   auto dstType =
631       this->getTypeConverter()->convertType(op.getResult().getType());
632   if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
633     return failure();
634   if (dstType == srcType) {
635     // Due to type conversion, we are seeing the same source and target type.
636     // Then we can just erase this operation by forwarding its operand.
637     rewriter.replaceOp(op, adaptor.getOperands().front());
638   } else {
639     rewriter.template replaceOpWithNewOp<SPIRVOp>(op, dstType,
640                                                   adaptor.getOperands());
641   }
642   return success();
643 }
644 
645 //===----------------------------------------------------------------------===//
646 // CmpIOpBooleanPattern
647 //===----------------------------------------------------------------------===//
648 
649 LogicalResult CmpIOpBooleanPattern::matchAndRewrite(
650     arith::CmpIOp op, OpAdaptor adaptor,
651     ConversionPatternRewriter &rewriter) const {
652   Type operandType = op.getLhs().getType();
653   if (!isBoolScalarOrVector(operandType))
654     return failure();
655 
656   switch (op.getPredicate()) {
657 #define DISPATCH(cmpPredicate, spirvOp)                                        \
658   case cmpPredicate:                                                           \
659     rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
660                                          adaptor.getLhs(), adaptor.getRhs());  \
661     return success();
662 
663     DISPATCH(arith::CmpIPredicate::eq, spirv::LogicalEqualOp);
664     DISPATCH(arith::CmpIPredicate::ne, spirv::LogicalNotEqualOp);
665 
666 #undef DISPATCH
667   default:;
668   }
669   return failure();
670 }
671 
672 //===----------------------------------------------------------------------===//
673 // CmpIOpPattern
674 //===----------------------------------------------------------------------===//
675 
676 LogicalResult
677 CmpIOpPattern::matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
678                                ConversionPatternRewriter &rewriter) const {
679   Type operandType = op.getLhs().getType();
680   if (isBoolScalarOrVector(operandType))
681     return failure();
682 
683   switch (op.getPredicate()) {
684 #define DISPATCH(cmpPredicate, spirvOp)                                        \
685   case cmpPredicate:                                                           \
686     if (spirvOp::template hasTrait<OpTrait::spirv::UnsignedOp>() &&            \
687         operandType != this->getTypeConverter()->convertType(operandType)) {   \
688       return op.emitError(                                                     \
689           "bitwidth emulation is not implemented yet on unsigned op");         \
690     }                                                                          \
691     rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
692                                          adaptor.getLhs(), adaptor.getRhs());  \
693     return success();
694 
695     DISPATCH(arith::CmpIPredicate::eq, spirv::IEqualOp);
696     DISPATCH(arith::CmpIPredicate::ne, spirv::INotEqualOp);
697     DISPATCH(arith::CmpIPredicate::slt, spirv::SLessThanOp);
698     DISPATCH(arith::CmpIPredicate::sle, spirv::SLessThanEqualOp);
699     DISPATCH(arith::CmpIPredicate::sgt, spirv::SGreaterThanOp);
700     DISPATCH(arith::CmpIPredicate::sge, spirv::SGreaterThanEqualOp);
701     DISPATCH(arith::CmpIPredicate::ult, spirv::ULessThanOp);
702     DISPATCH(arith::CmpIPredicate::ule, spirv::ULessThanEqualOp);
703     DISPATCH(arith::CmpIPredicate::ugt, spirv::UGreaterThanOp);
704     DISPATCH(arith::CmpIPredicate::uge, spirv::UGreaterThanEqualOp);
705 
706 #undef DISPATCH
707   }
708   return failure();
709 }
710 
711 //===----------------------------------------------------------------------===//
712 // CmpFOpPattern
713 //===----------------------------------------------------------------------===//
714 
715 LogicalResult
716 CmpFOpPattern::matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
717                                ConversionPatternRewriter &rewriter) const {
718   switch (op.getPredicate()) {
719 #define DISPATCH(cmpPredicate, spirvOp)                                        \
720   case cmpPredicate:                                                           \
721     rewriter.replaceOpWithNewOp<spirvOp>(op, op.getResult().getType(),         \
722                                          adaptor.getLhs(), adaptor.getRhs());  \
723     return success();
724 
725     // Ordered.
726     DISPATCH(arith::CmpFPredicate::OEQ, spirv::FOrdEqualOp);
727     DISPATCH(arith::CmpFPredicate::OGT, spirv::FOrdGreaterThanOp);
728     DISPATCH(arith::CmpFPredicate::OGE, spirv::FOrdGreaterThanEqualOp);
729     DISPATCH(arith::CmpFPredicate::OLT, spirv::FOrdLessThanOp);
730     DISPATCH(arith::CmpFPredicate::OLE, spirv::FOrdLessThanEqualOp);
731     DISPATCH(arith::CmpFPredicate::ONE, spirv::FOrdNotEqualOp);
732     // Unordered.
733     DISPATCH(arith::CmpFPredicate::UEQ, spirv::FUnordEqualOp);
734     DISPATCH(arith::CmpFPredicate::UGT, spirv::FUnordGreaterThanOp);
735     DISPATCH(arith::CmpFPredicate::UGE, spirv::FUnordGreaterThanEqualOp);
736     DISPATCH(arith::CmpFPredicate::ULT, spirv::FUnordLessThanOp);
737     DISPATCH(arith::CmpFPredicate::ULE, spirv::FUnordLessThanEqualOp);
738     DISPATCH(arith::CmpFPredicate::UNE, spirv::FUnordNotEqualOp);
739 
740 #undef DISPATCH
741 
742   default:
743     break;
744   }
745   return failure();
746 }
747 
748 //===----------------------------------------------------------------------===//
749 // CmpFOpNanKernelPattern
750 //===----------------------------------------------------------------------===//
751 
752 LogicalResult CmpFOpNanKernelPattern::matchAndRewrite(
753     arith::CmpFOp op, OpAdaptor adaptor,
754     ConversionPatternRewriter &rewriter) const {
755   if (op.getPredicate() == arith::CmpFPredicate::ORD) {
756     rewriter.replaceOpWithNewOp<spirv::OrderedOp>(op, adaptor.getLhs(),
757                                                   adaptor.getRhs());
758     return success();
759   }
760 
761   if (op.getPredicate() == arith::CmpFPredicate::UNO) {
762     rewriter.replaceOpWithNewOp<spirv::UnorderedOp>(op, adaptor.getLhs(),
763                                                     adaptor.getRhs());
764     return success();
765   }
766 
767   return failure();
768 }
769 
770 //===----------------------------------------------------------------------===//
771 // CmpFOpNanNonePattern
772 //===----------------------------------------------------------------------===//
773 
774 LogicalResult CmpFOpNanNonePattern::matchAndRewrite(
775     arith::CmpFOp op, OpAdaptor adaptor,
776     ConversionPatternRewriter &rewriter) const {
777   if (op.getPredicate() != arith::CmpFPredicate::ORD &&
778       op.getPredicate() != arith::CmpFPredicate::UNO)
779     return failure();
780 
781   Location loc = op.getLoc();
782 
783   Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getLhs());
784   Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, adaptor.getRhs());
785 
786   Value replace = rewriter.create<spirv::LogicalOrOp>(loc, lhsIsNan, rhsIsNan);
787   if (op.getPredicate() == arith::CmpFPredicate::ORD)
788     replace = rewriter.create<spirv::LogicalNotOp>(loc, replace);
789 
790   rewriter.replaceOp(op, replace);
791   return success();
792 }
793 
794 //===----------------------------------------------------------------------===//
795 // SelectOpPattern
796 //===----------------------------------------------------------------------===//
797 
798 LogicalResult
799 SelectOpPattern::matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
800                                  ConversionPatternRewriter &rewriter) const {
801   rewriter.replaceOpWithNewOp<spirv::SelectOp>(op, adaptor.getCondition(),
802                                                adaptor.getTrueValue(),
803                                                adaptor.getFalseValue());
804   return success();
805 }
806 
807 //===----------------------------------------------------------------------===//
808 // Pattern Population
809 //===----------------------------------------------------------------------===//
810 
811 void mlir::arith::populateArithmeticToSPIRVPatterns(
812     SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) {
813   // clang-format off
814   patterns.add<
815     ConstantCompositeOpPattern,
816     ConstantScalarOpPattern,
817     spirv::ElementwiseOpPattern<arith::AddIOp, spirv::IAddOp>,
818     spirv::ElementwiseOpPattern<arith::SubIOp, spirv::ISubOp>,
819     spirv::ElementwiseOpPattern<arith::MulIOp, spirv::IMulOp>,
820     spirv::ElementwiseOpPattern<arith::DivUIOp, spirv::UDivOp>,
821     spirv::ElementwiseOpPattern<arith::DivSIOp, spirv::SDivOp>,
822     spirv::ElementwiseOpPattern<arith::RemUIOp, spirv::UModOp>,
823     RemSIOpGLSLPattern, RemSIOpOCLPattern,
824     BitwiseOpPattern<arith::AndIOp, spirv::LogicalAndOp, spirv::BitwiseAndOp>,
825     BitwiseOpPattern<arith::OrIOp, spirv::LogicalOrOp, spirv::BitwiseOrOp>,
826     XOrIOpLogicalPattern, XOrIOpBooleanPattern,
827     spirv::ElementwiseOpPattern<arith::ShLIOp, spirv::ShiftLeftLogicalOp>,
828     spirv::ElementwiseOpPattern<arith::ShRUIOp, spirv::ShiftRightLogicalOp>,
829     spirv::ElementwiseOpPattern<arith::ShRSIOp, spirv::ShiftRightArithmeticOp>,
830     spirv::ElementwiseOpPattern<arith::NegFOp, spirv::FNegateOp>,
831     spirv::ElementwiseOpPattern<arith::AddFOp, spirv::FAddOp>,
832     spirv::ElementwiseOpPattern<arith::SubFOp, spirv::FSubOp>,
833     spirv::ElementwiseOpPattern<arith::MulFOp, spirv::FMulOp>,
834     spirv::ElementwiseOpPattern<arith::DivFOp, spirv::FDivOp>,
835     spirv::ElementwiseOpPattern<arith::RemFOp, spirv::FRemOp>,
836     TypeCastingOpPattern<arith::ExtUIOp, spirv::UConvertOp>, ExtUII1Pattern,
837     TypeCastingOpPattern<arith::ExtSIOp, spirv::SConvertOp>,
838     TypeCastingOpPattern<arith::ExtFOp, spirv::FConvertOp>,
839     TypeCastingOpPattern<arith::TruncIOp, spirv::SConvertOp>, TruncII1Pattern,
840     TypeCastingOpPattern<arith::TruncFOp, spirv::FConvertOp>,
841     TypeCastingOpPattern<arith::UIToFPOp, spirv::ConvertUToFOp>, UIToFPI1Pattern,
842     TypeCastingOpPattern<arith::SIToFPOp, spirv::ConvertSToFOp>,
843     TypeCastingOpPattern<arith::FPToSIOp, spirv::ConvertFToSOp>,
844     TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
845     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
846     CmpIOpBooleanPattern, CmpIOpPattern,
847     CmpFOpNanNonePattern, CmpFOpPattern,
848     SelectOpPattern,
849 
850     spirv::ElementwiseOpPattern<arith::MaxFOp, spirv::GLSLFMaxOp>,
851     spirv::ElementwiseOpPattern<arith::MaxSIOp, spirv::GLSLSMaxOp>,
852     spirv::ElementwiseOpPattern<arith::MaxUIOp, spirv::GLSLUMaxOp>,
853     spirv::ElementwiseOpPattern<arith::MinFOp, spirv::GLSLFMinOp>,
854     spirv::ElementwiseOpPattern<arith::MinSIOp, spirv::GLSLSMinOp>,
855     spirv::ElementwiseOpPattern<arith::MinUIOp, spirv::GLSLUMinOp>
856   >(typeConverter, patterns.getContext());
857   // clang-format on
858 
859   // Give CmpFOpNanKernelPattern a higher benefit so it can prevail when Kernel
860   // capability is available.
861   patterns.add<CmpFOpNanKernelPattern>(typeConverter, patterns.getContext(),
862                                        /*benefit=*/2);
863 }
864 
865 //===----------------------------------------------------------------------===//
866 // Pass Definition
867 //===----------------------------------------------------------------------===//
868 
869 namespace {
870 struct ConvertArithmeticToSPIRVPass
871     : public ConvertArithmeticToSPIRVBase<ConvertArithmeticToSPIRVPass> {
872   void runOnOperation() override {
873     auto module = getOperation();
874     auto targetAttr = spirv::lookupTargetEnvOrDefault(module);
875     auto target = SPIRVConversionTarget::get(targetAttr);
876 
877     SPIRVTypeConverter::Options options;
878     options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
879     SPIRVTypeConverter typeConverter(targetAttr, options);
880 
881     RewritePatternSet patterns(&getContext());
882     arith::populateArithmeticToSPIRVPatterns(typeConverter, patterns);
883     populateFuncToSPIRVPatterns(typeConverter, patterns);
884     populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns);
885 
886     if (failed(applyPartialConversion(module, *target, std::move(patterns))))
887       signalPassFailure();
888   }
889 };
890 } // namespace
891 
892 std::unique_ptr<Pass> mlir::arith::createConvertArithmeticToSPIRVPass() {
893   return std::make_unique<ConvertArithmeticToSPIRVPass>();
894 }
895